diff --git a/aiserver.py b/aiserver.py
index 141ab656..ebdf7fa8 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -219,6 +219,7 @@ class vars:
     abort       = False  # Whether or not generation was aborted by clicking on the submit button during generation
     compiling   = False  # If using a TPU Colab, this will be set to True when the TPU backend starts compiling and then set to False again
     checking    = False  # Whether or not we are actively checking to see if TPU backend is compiling or not
+    sp_changed  = False  # Set to True whenever a userscript changes the soft prompt so that check_for_sp_change() can alert the browser
     spfilename  = ""     # Filename of soft prompt to load, or an empty string if not using a soft prompt
     userscripts = []     # List of userscripts to load
     last_userscripts = []  # List of userscript filenames from the previous time userscripts were sent via usstatitems
@@ -699,13 +700,29 @@ def loadsettings():
 #==================================================================#
 #  Load a soft prompt from a file
 #==================================================================#
+
+def check_for_sp_change():
+    while(True):
+        time.sleep(0.1)
+        if(vars.sp_changed):
+            with app.app_context():
+                emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
+            vars.sp_changed = False
+
 def spRequest(filename):
+    if(not vars.allowsp):
+        raise RuntimeError("Soft prompts are not supported by your current model/backend")
+
+    old_filename = vars.spfilename
+
     vars.spfilename = ""
     settingschanged()
     if(len(filename) == 0):
         vars.sp = None
         vars.sp_length = 0
+        if(old_filename != filename):
+            vars.sp_changed = True
         return
 
     global np
@@ -713,7 +730,8 @@ def spRequest(filename):
         import numpy as np
 
     z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
-    assert isinstance(z, zipfile.ZipFile)
+    if not isinstance(z, zipfile.ZipFile):
+        raise RuntimeError(f"{repr(filename)} is not a valid soft prompt file")
     with z.open('meta.json') as f:
         vars.spmeta = json.load(f)
     z.close()
@@ -749,6 +767,8 @@ def spRequest(filename):
 
     vars.spfilename = filename
     settingschanged()
+    if(old_filename != filename):
+        vars.sp_changed = True
 
 #==================================================================#
 # Startup
@@ -1069,6 +1089,7 @@ from flask_socketio import SocketIO, emit
 app = Flask(__name__, root_path=os.getcwd())
 app.config['SECRET KEY'] = 'secret!'
 socketio = SocketIO(app, async_method="eventlet")
+socketio.start_background_task(check_for_sp_change)
 print("{0}OK!{1}".format(colors.GREEN, colors.END))
 
 # Start transformers and create pipeline
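Why a `vars.sp_changed` flag plus a polling task, rather than emitting directly from `spRequest()`: the soft prompt can now be changed by a userscript, from code that is not necessarily running inside a SocketIO event handler, so the patch records the change in a flag and lets a background task do the broadcast. Below is a minimal, self-contained sketch of the same flag-and-poll pattern; the `/poke` route, the `state` dict, and the event payload are hypothetical stand-ins, not KoboldAI's.

```python
from flask import Flask
from flask_socketio import SocketIO

app = Flask(__name__)
socketio = SocketIO(app)

state = {"changed": False}  # stands in for vars.sp_changed

@app.route("/poke")  # hypothetical trigger: any code path may set the flag
def poke():
    state["changed"] = True
    return "ok"

def watch_for_change():
    # Cooperative polling loop; socketio.sleep() yields to the event loop.
    while True:
        socketio.sleep(0.1)
        if state["changed"]:
            # SocketIO.emit() needs no request context and broadcasts by default.
            socketio.emit("from_server", {"cmd": "refresh"})
            state["changed"] = False

socketio.start_background_task(watch_for_change)

if __name__ == "__main__":
    socketio.run(app)
```

The sketch uses `SocketIO.emit()`, which works outside any request context; the patch instead wraps the context-bound `flask_socketio.emit()` in `app.app_context()`. Likewise, `socketio.sleep()` is the portable cooperative sleep here, whereas the patch's `time.sleep(0.1)` behaves the same only when eventlet's monkey patching is in effect.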
@@ -2199,6 +2220,30 @@ def lua_get_modelbackend():
 def lua_is_custommodel():
     return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
 
+#==================================================================#
+#  Return the filename (as a string) of the current soft prompt, or
+#  None if no soft prompt is loaded
+#==================================================================#
+@bridged_kwarg()
+def lua_get_spfilename():
+    return vars.spfilename.strip() or None
+
+#==================================================================#
+#  When called with a string as argument, sets the current soft prompt;
+#  when called with None as argument, uses no soft prompt.
+#  Returns True if the soft prompt changed, False otherwise.
+#==================================================================#
+@bridged_kwarg()
+def lua_set_spfilename(filename: Union[str, None]):
+    if(filename is None):
+        filename = ""
+    filename = str(filename).strip()
+    changed = lua_get_spfilename() != (filename or None)
+    if any(q in filename for q in ("/", "\\")):
+        raise RuntimeError("Soft prompt filename must not contain slashes")
+    spRequest(filename)
+    return changed
+
 #==================================================================#
 #  
 #==================================================================#
@@ -2613,7 +2657,6 @@ def get_message(msg):
         loadRequest(fileops.storypath(vars.loadselect))
     elif(msg['cmd'] == 'sprequest'):
         spRequest(vars.spselect)
-        emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, broadcast=True)
     elif(msg['cmd'] == 'deletestory'):
         deletesave(msg['data'])
     elif(msg['cmd'] == 'renamestory'):
diff --git a/bridge.lua b/bridge.lua
index 796bc33f..ed0941c6 100644
--- a/bridge.lua
+++ b/bridge.lua
@@ -165,7 +165,7 @@ return function(_python, _bridged)
     ---@field num_outputs integer
     ---@field feedback string
     ---@field is_config_file_open boolean
-    local kobold = setmetatable({API_VERSION = 1.0}, metawrapper)
+    local kobold = setmetatable({API_VERSION = 1.1}, metawrapper)
     local KoboldLib_mt = setmetatable({}, metawrapper)
     local KoboldLib_getters = setmetatable({}, metawrapper)
     local KoboldLib_setters = setmetatable({}, metawrapper)
@@ -1050,11 +1050,34 @@ return function(_python, _bridged)
             return
         elseif not bridged.vars.gamestarted and v == "" then
             error("`KoboldLib.submission` must not be set to the empty string when the story is empty")
+            return
         end
         bridged.vars.submission = v
     end
 
+    --==========================================================================
+    -- Userscript API: Soft prompt
+    --==========================================================================
+
+    ---@param t KoboldLib
+    ---@return string?
+    function KoboldLib_getters.spfilename(t)
+        return bridged.get_spfilename()
+    end
+
+    ---@param t KoboldLib
+    ---@param v string?
+    function KoboldLib_setters.spfilename(t, v)
+        if v ~= nil and (v:find("/") or v:find("\\")) then
+            error("Cannot set `KoboldLib.spfilename` to a string that contains slashes")
+        end
+        if bridged.set_spfilename(v) then
+            maybe_require_regeneration()
+        end
+    end
+
     --==========================================================================
     -- Userscript API: Model information
     --==========================================================================
diff --git a/fileops.py b/fileops.py
index cd269931..c303764e 100644
--- a/fileops.py
+++ b/fileops.py
@@ -122,7 +122,10 @@ def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile,
         shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
         assert len(shape) == 2
     except:
-        z.close()
+        try:
+            z.close()
+        except UnboundLocalError:
+            pass
         return 1, None, None, None, None
     if dtype not in ('V2', np.float16, np.float32):
         z.close()
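The fileops.py hunk fixes a cleanup path: if `checksp()`'s `try` block raises before `z` is bound (for example, if opening the zip file itself fails), the bare `z.close()` in the handler raised `UnboundLocalError` and masked the real failure. A self-contained sketch of the failure mode with a hypothetical reader; pre-initializing `z = None` is an equivalent alternative to catching `UnboundLocalError` as the patch does.

```python
import zipfile
from typing import Optional

def try_open(path: str) -> Optional[zipfile.ZipFile]:
    z = None
    try:
        z = zipfile.ZipFile(path)  # can raise before `z` is ever assigned
        z.testzip()                # further validation may raise as well
        return z
    except Exception:
        if z is not None:          # guard instead of `except UnboundLocalError`
            z.close()
        return None

print(try_open("does_not_exist.zip"))  # None, with no UnboundLocalError
```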
diff --git a/userscripts/api_documentation.html b/userscripts/api_documentation.html
index f504f821..86230366 100644
--- a/userscripts/api_documentation.html
+++ b/userscripts/api_documentation.html
(HTML hunks shown as rendered text, markup omitted. The first hunk adds the new field to the table of contents:)
@@ -56,6 +56,7 @@
 kobold.num_outputs
 kobold.outputs
 kobold.settings
+kobold.spfilename
 kobold.story
 kobold.submission
 kobold.worldinfo
(A second hunk inserts the field's documentation after the kobold.settings section:)
 kobold.settings.setwidepth    (World Info Depth)
 kobold.settings.setuseprompt  (Always Use Prompt)
+kobold.spfilename
+
+Readable from: anywhere
+Writable from: anywhere
+
+field kobold.spfilename: string?
+
+The name of the soft prompt file to use (as a string), including the file extension. If not using a soft prompt, this is nil instead.
+
+You can also set the soft prompt to use by setting this to a string or nil.
+
+Modifying this field from inside of a generation modifier triggers a regeneration, which means that the context is recomputed after modification and generation begins again with the new context and previously generated tokens. This incurs a small performance penalty and should not be performed in excess.
 Readable from: anywhere
 Writable from: nowhere
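As a cross-check of the semantics documented above, here is a self-contained sketch of the validation and change-detection logic that `lua_set_spfilename()` implements; the function name, `_current` variable, and example filename are hypothetical, and the empty string stands for "no soft prompt", as with `vars.spfilename`.

```python
from typing import Optional

_current = ""  # "" means no soft prompt loaded, as with vars.spfilename

def set_spfilename(filename: Optional[str]) -> bool:
    """Set (or clear, with None) the soft prompt; return True if it changed."""
    global _current
    filename = "" if filename is None else str(filename).strip()
    if any(sep in filename for sep in ("/", "\\")):
        raise RuntimeError("soft prompt filename must not contain slashes")
    # Normalize "" to None on both sides so clearing an already-empty
    # soft prompt does not count as a change.
    changed = (_current or None) != (filename or None)
    _current = filename  # the real implementation loads the file here via spRequest()
    return changed

print(set_spfilename("example.zip"))  # True  (loaded)
print(set_spfilename("example.zip"))  # False (unchanged, though it still reloads)
print(set_spfilename(None))           # True  (cleared)
```

Note that, as in the patch, setting the same filename again still reloads the soft prompt; only the return value (and the browser notification via `vars.sp_changed`) depends on whether the name actually changed.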