Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge branch 'united' into patch
aiserver.py: 29 lines changed
@@ -14,6 +14,7 @@ os.environ['EVENTLET_THREADPOOL_SIZE'] = '50'
 from eventlet import tpool
 
 from os import path, getcwd
+import time
 import re
 import json
 import collections
@@ -128,6 +129,8 @@ class vars:
     lua_deleted = set()  # Set of chunk numbers that were deleted from a Lua generation modifier
     generated_tkns = 0   # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
     abort       = False  # Whether or not generation was aborted by clicking on the submit button during generation
+    compiling   = False  # If using a TPU Colab, this will be set to True when the TPU backend starts compiling and then set to False again
+    checking    = False  # Whether or not we are actively checking to see if the TPU backend is compiling
     spfilename  = ""     # Filename of soft prompt to load, or an empty string if not using a soft prompt
     userscripts = []     # List of userscripts to load
     last_userscripts = []  # List of userscript filenames from the previous time userscripts were sent via usstatitems
@@ -639,7 +642,7 @@ log.setLevel(logging.ERROR)
 
 # Start flask & SocketIO
 print("{0}Initializing Flask... {1}".format(colors.PURPLE, colors.END), end="")
-from flask import Flask, render_template, Response, request
+from flask import Flask, render_template, Response, request, copy_current_request_context
 from flask_socketio import SocketIO, emit
 app = Flask(__name__)
 app.config['SECRET KEY'] = 'secret!'
@@ -1052,6 +1055,13 @@ else:
                     break
         return excluded_world_info, regeneration_required, halt
 
+    def tpumtjgenerate_compiling_callback() -> None:
+        print(colors.GREEN + "TPU backend compilation triggered" + colors.END)
+        vars.compiling = True
+
+    def tpumtjgenerate_stopped_compiling_callback() -> None:
+        vars.compiling = False
+
     # If we're running Colab or OAI, we still need a tokenizer.
     if(vars.model == "Colab"):
         from transformers import GPT2TokenizerFast
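Note: this hunk defines only the frontend half of the hook pair; the backend half of the contract is not part of the diff. Below is a plausible sketch of how tpu_mtj_backend presumably consumes these hooks. It is illustrative only: _jitted_generate is a hypothetical name, and bracketing every generate call is an assumption, since the real module may instead detect compilation and fire the hooks conditionally.

    # Sketch of the presumed backend-side contract (not the real tpu_mtj_backend):
    # module-level callables that aiserver.py overwrites before load_model().
    def compiling_callback() -> None:
        pass  # no-op default until the frontend installs its hook

    def stopped_compiling_callback() -> None:
        pass  # no-op default

    def _generate(context):
        # Tracing a new input shape forces an XLA recompile, which can take
        # minutes on a TPU; bracketing the call lets the UI warn the user.
        compiling_callback()
        try:
            return _jitted_generate(context)  # hypothetical jitted function
        finally:
            stopped_compiling_callback()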
@@ -1066,6 +1076,8 @@ else:
         import tpu_mtj_backend
         tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
         tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
+        tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
+        tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
         tpu_mtj_backend.load_model(vars.custmodpth)
         vars.allowsp = True
         vars.modeldim = int(tpu_mtj_backend.params["d_model"])
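Note: the new assignments sit above tpu_mtj_backend.load_model, which matters if loading the model can itself trigger the first compilation (plausible for a JAX backend, though this diff does not show it); the hooks must already be installed at that point. Condensed, restating the hunk with the ordering made explicit:

    tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
    tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
    tpu_mtj_backend.load_model(vars.custmodpth)  # may already fire the hooks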
@@ -1643,6 +1655,7 @@ def execute_genmod():
     vars.lua_koboldbridge.execute_genmod()
 
 def execute_outmod():
+    emit('from_server', {'cmd': 'hidemsg', 'data': ''}, broadcast=True)
    try:
         tpool.execute(vars.lua_koboldbridge.execute_outmod)
     except lupa.LuaError as e:
@@ -2259,6 +2272,18 @@ def settingschanged():
 #==================================================================#
 # Take input text from SocketIO and decide what to do with it
 #==================================================================#
+
+def check_for_backend_compilation():
+    if(vars.checking):
+        return
+    vars.checking = True
+    for _ in range(31):
+        time.sleep(0.06276680299820175)
+        if(vars.compiling):
+            emit('from_server', {'cmd': 'warnmsg', 'data': 'Compiling TPU backend—this usually takes 1–2 minutes...'}, broadcast=True)
+            break
+    vars.checking = False
+
 def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False):
     # Ignore new submissions if the AI is currently busy
     if(vars.aibusy):
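Note: vars.checking debounces the poller so concurrent submissions don't stack watchers, and the loop gives up after 31 polls. The constants work out to a watch window of just under two seconds per submission; arithmetic sketch below (the "already compiled, no warning needed" reading is an interpretation, not stated in the source):

    interval = 0.06276680299820175   # seconds between polls, from the diff
    polls = 31
    print(polls * interval)          # ~1.946 s total watch window per submit
    # If vars.compiling never goes True within the window, the generate
    # function was presumably already compiled and no banner is broadcast.

The 'hidemsg' emit added to execute_outmod in the earlier hunk appears to be the other half of this: it clears the warning banner once generation output actually arrives.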
@@ -2972,6 +2997,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
 
     global past
 
+    socketio.start_background_task(copy_current_request_context(check_for_backend_compilation))
+
     if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
 
         context = np.tile(np.uint32(txt), (vars.numseqs, 1))
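Note: flask_socketio.emit needs the request context of the originating SocketIO event (that is where it reads the namespace from), but start_background_task runs the function outside any request. Wrapping the poller with copy_current_request_context captures that context at submission time, which is why the Flask import was extended in the earlier hunk. A minimal self-contained sketch of the same pattern, in a hypothetical app rather than aiserver.py:

    from flask import Flask, copy_current_request_context
    from flask_socketio import SocketIO, emit

    app = Flask(__name__)
    socketio = SocketIO(app)

    @socketio.on('message')
    def handle_message(data):
        @copy_current_request_context
        def watch():
            socketio.sleep(1)  # stand-in for the compilation-polling loop
            emit('from_server', {'cmd': 'warnmsg', 'data': 'still working...'},
                 broadcast=True)
        socketio.start_background_task(watch)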