From a3a52dc9c35df054ba41d7db23a24baf3268f6ca Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 12 Apr 2022 15:59:05 -0400
Subject: [PATCH 1/3] Add support for changing soft prompt from userscripts

---
 aiserver.py | 47 +++++++++++++++++++++++++++++++++++++++++++++--
 bridge.lua  | 23 +++++++++++++++++++++++
 fileops.py  |  5 ++++-
 3 files changed, 72 insertions(+), 3 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index db0e8872..5d519e80 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -217,6 +217,7 @@ class vars:
     abort = False  # Whether or not generation was aborted by clicking on the submit button during generation
     compiling = False  # If using a TPU Colab, this will be set to True when the TPU backend starts compiling and then set to False again
     checking = False  # Whether or not we are actively checking to see if TPU backend is compiling or not
+    sp_changed = False  # This gets set to True whenever a userscript changes the soft prompt so that check_for_sp_change() can alert the browser that the soft prompt has changed
     spfilename = ""  # Filename of soft prompt to load, or an empty string if not using a soft prompt
     userscripts = []  # List of userscripts to load
     last_userscripts = []  # List of previous userscript filenames from the previous time userscripts were send via usstatitems
@@ -697,13 +698,29 @@ def loadsettings():
 #==================================================================#
 # Load a soft prompt from a file
 #==================================================================#
+
+def check_for_sp_change():
+    while(True):
+        time.sleep(0.1)
+        if(vars.sp_changed):
+            with app.app_context():
+                emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
+            vars.sp_changed = False
+
 def spRequest(filename):
+    if(not vars.allowsp):
+        raise RuntimeError("Soft prompts are not supported by your current model/backend")
+
+    old_filename = vars.spfilename
+
     vars.spfilename = ""
     settingschanged()
 
     if(len(filename) == 0):
         vars.sp = None
         vars.sp_length = 0
+        if(old_filename != filename):
+            vars.sp_changed = True
         return
 
     global np
@@ -711,7 +728,8 @@ def spRequest(filename):
         import numpy as np
 
     z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
-    assert isinstance(z, zipfile.ZipFile)
+    if not isinstance(z, zipfile.ZipFile):
+        raise RuntimeError(f"{repr(filename)} is not a valid soft prompt file")
     with z.open('meta.json') as f:
         vars.spmeta = json.load(f)
     z.close()
@@ -747,6 +765,8 @@ def spRequest(filename):
     vars.spfilename = filename
     settingschanged()
+    if(old_filename != filename):
+        vars.sp_changed = True
 
 #==================================================================#
 # Startup
 #==================================================================#
@@ -1067,6 +1087,7 @@ from flask_socketio import SocketIO, emit
 app = Flask(__name__, root_path=os.getcwd())
 app.config['SECRET KEY'] = 'secret!'
 socketio = SocketIO(app, async_method="eventlet")
+socketio.start_background_task(check_for_sp_change)
 print("{0}OK!{1}".format(colors.GREEN, colors.END))
 
 # Start transformers and create pipeline
@@ -2197,6 +2218,29 @@ def lua_get_modelbackend():
 def lua_is_custommodel():
     return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
 
+#==================================================================#
+#  Return the filename (as a string) of the current soft prompt, or
+#  None if no soft prompt is loaded
+#==================================================================#
+@bridged_kwarg()
+def lua_get_spfilename():
+    return vars.spfilename.strip() or None
+
+#==================================================================#
+#  When called with a string as its argument, sets the current soft
+#  prompt; when called with None as its argument, uses no soft prompt.
+#  Returns True if the soft prompt changed, False otherwise.
+#==================================================================#
+@bridged_kwarg()
+def lua_set_spfilename(filename: Union[str, None]):
+    if(filename is None):
+        filename = ""
+    filename = str(filename).strip()
+    changed = vars.spfilename.strip() != filename
+    assert all(q not in filename for q in ("/", "\\"))
+    spRequest(filename)
+    return changed
+
 #==================================================================#
 #  
 #==================================================================#
@@ -2611,7 +2655,6 @@ def get_message(msg):
         loadRequest(fileops.storypath(vars.loadselect))
     elif(msg['cmd'] == 'sprequest'):
         spRequest(vars.spselect)
-        emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, broadcast=True)
     elif(msg['cmd'] == 'deletestory'):
         deletesave(msg['data'])
     elif(msg['cmd'] == 'renamestory'):
diff --git a/bridge.lua b/bridge.lua
index 796bc33f..1a15c792 100644
--- a/bridge.lua
+++ b/bridge.lua
@@ -1050,11 +1050,34 @@ return function(_python, _bridged)
             return
         elseif not bridged.vars.gamestarted and v == "" then
             error("`KoboldLib.submission` must not be set to the empty string when the story is empty")
+            return
         end
         bridged.vars.submission = v
     end
 
+    --==========================================================================
+    -- Userscript API: Soft prompt
+    --==========================================================================
+
+    ---@param t KoboldLib
+    ---@return string?
+    function KoboldLib_getters.spfilename(t)
+        return bridged.get_spfilename()
+    end
+
+    ---@param t KoboldLib
+    ---@param v string?
+    function KoboldLib_setters.spfilename(t, v)
+        if v ~= nil and (v:find("/") or v:find("\\")) then
+            error("Cannot set `KoboldLib.spfilename` to a string that contains slashes")
+        end
+        if bridged.set_spfilename(v) then
+            maybe_require_regeneration()
+        end
+    end
+
     --==========================================================================
     -- Userscript API: Model information
     --==========================================================================
 
diff --git a/fileops.py b/fileops.py
index cd269931..c303764e 100644
--- a/fileops.py
+++ b/fileops.py
@@ -122,7 +122,10 @@ def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile,
             shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
             assert len(shape) == 2
     except:
-        z.close()
+        try:
+            z.close()
+        except UnboundLocalError:
+            pass
         return 1, None, None, None, None
     if dtype not in ('V2', np.float16, np.float32):
         z.close()

From efea584d84448872a29626814809b0479a1db23b Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 14 Apr 2022 14:58:11 -0400
Subject: [PATCH 2/3] Update API documentation

---
 userscripts/api_documentation.html | 10 ++++++++++
 userscripts/api_documentation.md   | 16 ++++++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/userscripts/api_documentation.html b/userscripts/api_documentation.html
index f504f821..86230366 100644
--- a/userscripts/api_documentation.html
+++ b/userscripts/api_documentation.html
@@ -56,6 +56,7 @@
  [rendered text of the HTML diff; markup not shown]
   • kobold.num_outputs
   • kobold.outputs
   • kobold.settings
+  • kobold.spfilename
   • kobold.story
   • kobold.submission
   • kobold.worldinfo

+   kobold.spfilename
+
+   Readable from: anywhere
+   Writable from: anywhere
+
+   field kobold.spfilename: string?
+
+   The name of the soft prompt file to use (as a string), including the file extension. If not using a soft prompt, this is nil instead.
+
+   You can also set the soft prompt to use by setting this to a string or nil.
+
+   Modifying this field from inside of a generation modifier triggers a regeneration, which means that the context is recomputed after modification and generation begins again with the new context and previously generated tokens. This incurs a small performance penalty and should not be performed in excess.

    kobold.story

    Readable from: anywhere
    Writable from: nowhere

diff --git a/userscripts/api_documentation.md b/userscripts/api_documentation.md
index 1c37c644..fda69670 100644
--- a/userscripts/api_documentation.md
+++ b/userscripts/api_documentation.md
@@ -29,6 +29,7 @@ global kobold: KoboldLib
 * `kobold.num_outputs`
 * `kobold.outputs`
 * `kobold.settings`
+* `kobold.spfilename`
 * `kobold.story`
 * `kobold.submission`
 * `kobold.worldinfo`
@@ -372,6 +373,21 @@ Modifying certain fields from inside of a generation modifier triggers a regener
 * `kobold.settings.setwidepth` (World Info Depth)
 * `kobold.settings.setuseprompt` (Always Use Prompt)
 
+# kobold.spfilename
+
+***Readable from:*** anywhere
+***Writable from:*** anywhere
+
+```lua
+field kobold.spfilename: string?
+```
+
+The name of the soft prompt file to use (as a string), including the file extension. If not using a soft prompt, this is `nil` instead.
+
+You can also set the soft prompt to use by setting this to a string or `nil`.
+
+Modifying this field from inside of a generation modifier triggers a regeneration, which means that the context is recomputed after modification and generation begins again with the new context and previously generated tokens. This incurs a small performance penalty and should not be performed in excess.
+
 # kobold.story
 
 ***Readable from:*** anywhere

From dcdd0263fc8ef125568d56949663f7bc7647061d Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 14 Apr 2022 15:00:41 -0400
Subject: [PATCH 3/3] Increment `API_VERSION` in bridge.lua

---
 bridge.lua | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bridge.lua b/bridge.lua
index 1a15c792..ed0941c6 100644
--- a/bridge.lua
+++ b/bridge.lua
@@ -165,7 +165,7 @@ return function(_python, _bridged)
     ---@field num_outputs integer
     ---@field feedback string
     ---@field is_config_file_open boolean
-    local kobold = setmetatable({API_VERSION = 1.0}, metawrapper)
+    local kobold = setmetatable({API_VERSION = 1.1}, metawrapper)
     local KoboldLib_mt = setmetatable({}, metawrapper)
     local KoboldLib_getters = setmetatable({}, metawrapper)
     local KoboldLib_setters = setmetatable({}, metawrapper)
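
Usage note (illustrative only, not part of the patches above): a minimal userscript sketch that drives the new `kobold.spfilename` field. The soft prompt name `surrealism.zip` is a hypothetical placeholder; it must match a file that actually exists in your softprompts folder.

```lua
-- A minimal userscript sketch that drives kobold.spfilename.
-- "surrealism.zip" is a placeholder; point it at a real file in your
-- softprompts folder.
local userscript = {}

-- Input modifier: runs on each submission, before generation. Assigning
-- here (rather than in a generation modifier) avoids the regeneration
-- penalty described in the documentation above.
function userscript.inmod()
    if kobold.spfilename == nil then
        kobold.spfilename = "surrealism.zip"
    end
end

-- Output modifier: report which soft prompt was active for this generation.
function userscript.outmod()
    print("Soft prompt in use:", kobold.spfilename)
end

return userscript
```

Assigning `nil` unloads the soft prompt again; in both directions the setter calls `bridged.set_spfilename`, which only reports a change (and therefore only triggers `maybe_require_regeneration`) when the filename actually differs from the current one.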