From 63bb76b073535948eb69e7c59dc3bd8d881c6fa2 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 4 Jan 2022 14:13:36 -0500 Subject: [PATCH 1/9] Make sure `vars.wifolders_u` is set up properly on loading a save --- aiserver.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/aiserver.py b/aiserver.py index 80de124e..2bd43be0 100644 --- a/aiserver.py +++ b/aiserver.py @@ -3867,6 +3867,8 @@ def loadRequest(loadpath, filename=None): break vars.worldinfo_u[uid] = vars.worldinfo[-1] vars.worldinfo[-1]["uid"] = uid + if(vars.worldinfo[-1]["folder"] is not None): + vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1]) num += 1 for uid in vars.wifolders_l + [None]: @@ -3877,6 +3879,8 @@ def loadRequest(loadpath, filename=None): break vars.worldinfo_u[uid] = vars.worldinfo[-1] vars.worldinfo[-1]["uid"] = uid + if(vars.worldinfo[-1]["folder"] is not None): + vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1]) stablesortwi() vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]] @@ -4072,6 +4076,8 @@ def importgame(): break vars.worldinfo_u[uid] = vars.worldinfo[-1] vars.worldinfo[-1]["uid"] = uid + if(vars.worldinfo[-1]["folder"]) is not None: + vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1]) num += 1 for uid in vars.wifolders_l + [None]: @@ -4082,6 +4088,8 @@ def importgame(): break vars.worldinfo_u[uid] = vars.worldinfo[-1] vars.worldinfo[-1]["uid"] = uid + if(vars.worldinfo[-1]["folder"] is not None): + vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1]) stablesortwi() vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]] @@ -4151,6 +4159,8 @@ def importAidgRequest(id): break vars.worldinfo_u[uid] = vars.worldinfo[-1] vars.worldinfo[-1]["uid"] = uid + if(vars.worldinfo[-1]["folder"]) is not None: + vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1]) num += 1 for uid in vars.wifolders_l + [None]: @@ -4161,6 +4171,8 @@ def importAidgRequest(id): break vars.worldinfo_u[uid] = vars.worldinfo[-1] vars.worldinfo[-1]["uid"] = uid + if(vars.worldinfo[-1]["folder"] is not None): + vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1]) stablesortwi() vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]] @@ -4210,6 +4222,8 @@ def wiimportrequest(): break vars.worldinfo_u[uid] = vars.worldinfo[-1] vars.worldinfo[-1]["uid"] = uid + if(vars.worldinfo[-1]["folder"]) is not None: + vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1]) num += 1 for uid in [None]: vars.worldinfo.append({"key": "", "keysecondary": "", "content": "", "comment": "", "folder": uid, "num": None, "init": False, "selective": False, "constant": False, "uid": None}) @@ -4219,6 +4233,8 @@ def wiimportrequest(): break vars.worldinfo_u[uid] = vars.worldinfo[-1] vars.worldinfo[-1]["uid"] = uid + if(vars.worldinfo[-1]["folder"] is not None): + vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1]) print("{0}".format(vars.worldinfo[0])) From f46ebd235953fce7ed4a5c4dcb90dc8c7f3af36d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 4 Jan 2022 14:18:58 -0500 Subject: [PATCH 2/9] Always pass 1.1 as repetition penalty to generator The `dynamic_processor_wrap` makes it so that the repetition penalty is read directly from `vars`, but this only works if the initial repetition sent to `generator` is not equal to 1. So we are now forcing the initial repetition penalty to be something other than 1. --- aiserver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index 2bd43be0..cdb7a661 100644 --- a/aiserver.py +++ b/aiserver.py @@ -2560,7 +2560,7 @@ def _generate(txt, minimum, maximum, found_entries): do_sample=True, min_length=minimum, max_length=int(2e9), - repetition_penalty=vars.rep_pen, + repetition_penalty=1.1, bad_words_ids=vars.badwordsids, use_cache=True, num_return_sequences=numseqs From e20452ddd880330bb257d395d7f8721432972486 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 4 Jan 2022 14:40:10 -0500 Subject: [PATCH 3/9] Retrying random story generation now also remembers memory --- aiserver.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/aiserver.py b/aiserver.py index cdb7a661..3eefa16d 100644 --- a/aiserver.py +++ b/aiserver.py @@ -163,6 +163,7 @@ class vars: genseqs = [] # Temporary storage for generated sequences recentback = False # Whether Back button was recently used without Submitting or Retrying after recentrng = None # If a new random game was recently generated without Submitting after, this is the topic used (as a string), otherwise this is None + recentrngm = None # If a new random game was recently generated without Submitting after, this is the memory used (as a string), otherwise this is None useprompt = False # Whether to send the full prompt with every submit action breakmodel = False # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J only, currently) @@ -1645,7 +1646,7 @@ def get_message(msg): vars.chatname = msg['chatname'] settingschanged() emit('from_server', {'cmd': 'setchatname', 'data': vars.chatname}, broadcast=True) - vars.recentrng = None + vars.recentrng = vars.recentrngm = None actionsubmit(msg['data'], actionmode=msg['actionmode']) elif(vars.mode == "edit"): editsubmit(msg['data']) @@ -2130,7 +2131,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, set_aibusy(1) if(disable_recentrng): - vars.recentrng = None + vars.recentrng = vars.recentrngm = None vars.recentback = False vars.recentedit = False @@ -2272,7 +2273,7 @@ def actionretry(data): if(vars.aibusy): return if(vars.recentrng is not None): - randomGameRequest(vars.recentrng) + randomGameRequest(vars.recentrng, memory=vars.recentrngm) return # Remove last action if possible and resubmit if(vars.gamestarted if vars.useprompt else len(vars.actions) > 0): @@ -4282,6 +4283,7 @@ def randomGameRequest(topic, memory=""): newGameRequest() return vars.recentrng = topic + vars.recentrngm = memory newGameRequest() _memory = memory if(len(memory) > 0): From 2fc0bdfcba5028729a81bf5155ee1418adb19a85 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 4 Jan 2022 14:41:31 -0500 Subject: [PATCH 4/9] Correct a typo in `restorePrompt()` --- static/application.js | 2 +- templates/index.html | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/static/application.js b/static/application.js index 4d9ab676..f97333d6 100644 --- a/static/application.js +++ b/static/application.js @@ -1511,7 +1511,7 @@ function restorePrompt() { if(shadow_text.length && shadow_text[0].firstChild && (shadow_text[0].firstChild.nodeType === 3 || shadow_text[0].firstChild.tagName === "BR")) { detected = true; ref = shadow_text; - } else if(game_text.length && game_text[0].firstChild && game_text[0].firstChild.nodeType === 3 || game_text[0].firstChild.tagName === "BR") { + } else if(game_text.length && game_text[0].firstChild && (game_text[0].firstChild.nodeType === 3 || game_text[0].firstChild.tagName === "BR")) { detected = true; ref = game_text; } diff --git a/templates/index.html b/templates/index.html index 41a81293..08357979 100644 --- a/templates/index.html +++ b/templates/index.html @@ -17,7 +17,7 @@ - + From aa86c6001c52746445575303b4298a48ea623591 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 4 Jan 2022 14:43:37 -0500 Subject: [PATCH 5/9] `--breakmodel_gpublocks` should handle -1 properly now --- aiserver.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/aiserver.py b/aiserver.py index 3eefa16d..8843b030 100644 --- a/aiserver.py +++ b/aiserver.py @@ -272,6 +272,13 @@ def device_config(model): try: breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(','))) assert len(breakmodel.gpu_blocks) <= torch.cuda.device_count() + s = n_layers + for i in range(len(breakmodel.gpu_blocks)): + if(breakmodel.gpu_blocks[i] <= -1): + breakmodel.gpu_blocks[i] = s + break + else: + s -= breakmodel.gpu_blocks[i] assert sum(breakmodel.gpu_blocks) <= n_layers n_layers -= sum(breakmodel.gpu_blocks) except: From 6edc6387f4d583e13879372931aaa4f41a254820 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 4 Jan 2022 17:11:14 -0500 Subject: [PATCH 6/9] Accept command line arguments in `KOBOLDAI_ARGS` environment var So that you can use gunicorn or whatever with command-line arguments by passing the arguments in an environment variable. --- aiserver.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/aiserver.py b/aiserver.py index 8843b030..67d150e0 100644 --- a/aiserver.py +++ b/aiserver.py @@ -15,8 +15,6 @@ from eventlet import tpool from os import path, getcwd import re -import tkinter as tk -from tkinter import messagebox import json import collections import zipfile @@ -388,7 +386,12 @@ parser.add_argument("--override_delete", action='store_true', help="Deleting sto parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.") parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.") -args = parser.parse_args() +args: argparse.Namespace = None +if(os.environ.get("KOBOLDAI_ARGS") is not None): + import shlex + args = parser.parse_args(shlex.split(os.environ["KOBOLDAI_ARGS"])) +else: + args = parser.parse_args() vars.model = args.model; if args.remote: @@ -4308,6 +4311,7 @@ loadsettings() # Final startup commands to launch Flask app #==================================================================# if __name__ == "__main__": + print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END)) # Start Flask/SocketIO (Blocking, so this must be last method!) @@ -4321,12 +4325,15 @@ if __name__ == "__main__": cloudflare = _run_cloudflared(5000) with open('cloudflare.log', 'w') as cloudflarelog: cloudflarelog.write("KoboldAI has finished loading and is available at the following link : " + cloudflare) - print("\n" + format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link : " + cloudflare + format(colors.END)) + print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link : " + cloudflare + format(colors.END)) vars.serverstarted = True socketio.run(app, host='0.0.0.0', port=5000) else: import webbrowser webbrowser.open_new('http://localhost:5000') - print("{0}\nServer started!\nYou may now connect with a browser at http://127.0.0.1:5000/{1}".format(colors.GREEN, colors.END)) + print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:5000/{1}".format(colors.GREEN, colors.END)) vars.serverstarted = True socketio.run(app, port=5000) + +else: + print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END)) From fbf506207419df9b63884131af0c462ac0f90b91 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 4 Jan 2022 19:26:59 -0500 Subject: [PATCH 7/9] Add option to `compute_context()` to not scan story --- aiserver.py | 26 +++++++++++++++++----- bridge.lua | 10 +++++---- userscripts/kaipreset_location_scanner.lua | 4 ++-- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/aiserver.py b/aiserver.py index 67d150e0..1e5556a9 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1152,8 +1152,10 @@ def lua_encode(string): # Computes context given a submission, Lua array of entry UIDs and a Lua array # of folder UIDs #==================================================================# -def lua_compute_context(submission, entries, folders): +def lua_compute_context(submission, entries, folders, kwargs): assert type(submission) is str + if(kwargs is None): + kwargs = vars.lua_state.table() actions = vars._actions if vars.lua_koboldbridge.userstate == "genmod" else vars.actions allowed_entries = None allowed_folders = None @@ -1169,8 +1171,20 @@ def lua_compute_context(submission, entries, folders): while(folders[i] is not None): allowed_folders.add(int(folders[i])) i += 1 - winfo, mem, anotetxt, _ = calcsubmitbudgetheader(submission, allowed_entries=allowed_entries, allowed_folders=allowed_folders, force_use_txt=True) - txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions) + winfo, mem, anotetxt, _ = calcsubmitbudgetheader( + submission, + allowed_entries=allowed_entries, + allowed_folders=allowed_folders, + force_use_txt=True, + scan_story=kwargs["scan_story"] if kwargs["scan_story"] != None else True, + ) + txt, _, _ = calcsubmitbudget( + len(actions), + winfo, + mem, + anotetxt, + actions, + ) return tokenizer.decode(txt) #==================================================================# @@ -3370,7 +3384,7 @@ def deletewifolder(uid): #==================================================================# # Look for WI keys in text to generator #==================================================================# -def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False): +def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True): original_txt = txt # Dont go any further if WI is empty @@ -3381,7 +3395,7 @@ def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_tx ln = len(vars.actions) # Don't bother calculating action history if widepth is 0 - if(vars.widepth > 0): + if(vars.widepth > 0 and scan_story): depth = vars.widepth # If this is not a continue, add 1 to widepth since submitted # text is already in action history @ -1 @@ -3423,7 +3437,7 @@ def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_tx found_entries.add(id(wi)) continue - if(wi["key"] != ""): + if(len(wi["key"].strip()) > 0 and (not wi.get("selective", False) or len(wi.get("keysecondary", "").strip()) > 0)): # Split comma-separated keys keys = wi["key"].split(",") keys_secondary = wi.get("keysecondary", "").split(",") diff --git a/bridge.lua b/bridge.lua index 460cbf8d..d5adbd01 100644 --- a/bridge.lua +++ b/bridge.lua @@ -383,8 +383,9 @@ return function(_python, _bridged) end ---@param submission? string + ---@param kwargs? table ---@return string - function KoboldWorldInfoEntry:compute_context(submission) + function KoboldWorldInfoEntry:compute_context(submission, kwargs) if not check_validity(self) then return "" elseif submission == nil then @@ -393,7 +394,7 @@ return function(_python, _bridged) error("`compute_context` takes a string or nil as argument #1, but got a " .. type(submission)) return "" end - return bridged.compute_context(submission, {self.uid}, nil) + return bridged.compute_context(submission, {self.uid}, nil, kwargs) end ---@generic K @@ -484,8 +485,9 @@ return function(_python, _bridged) ---@param submission? string ---@param entries? KoboldWorldInfoEntry|table + ---@param kwargs? table ---@return string - function KoboldWorldInfoFolder:compute_context(submission, entries) + function KoboldWorldInfoFolder:compute_context(submission, entries, kwargs) if not check_validity(self) then return "" elseif submission == nil then @@ -513,7 +515,7 @@ return function(_python, _bridged) if self.name == "KoboldWorldInfoFolder" then folders = {rawget(self, "_uid")} end - return bridged.compute_context(submission, _entries, folders) + return bridged.compute_context(submission, _entries, folders, kwargs) end ---@return boolean diff --git a/userscripts/kaipreset_location_scanner.lua b/userscripts/kaipreset_location_scanner.lua index 4348e322..06f7e1ca 100644 --- a/userscripts/kaipreset_location_scanner.lua +++ b/userscripts/kaipreset_location_scanner.lua @@ -26,7 +26,7 @@ local example_config = [[;-- Location scanner ;-- Usage instructions: ;-- ;-- 1. Create a world info folder with name containing the string -;-- "<||ls||>" (without the double quotes). The comment can be anything as +;-- "<||ls||>" (without the double quotes). The name can be anything as ;-- long as it contains that inside it somewhere -- for example, you could ;-- set the name to "Locations <||ls||>". ;-- @@ -124,7 +124,7 @@ function userscript.inmod() key = e.key, keysecondary = e.keysecondary, } - e.constant = e.constant or (not repeated and e:compute_context("") ~= e:compute_context(location)) + e.constant = e.constant or (not repeated and e:compute_context("", {scan_story=false}) ~= e:compute_context(location, {scan_story=false})) e.key = "" e.keysecondary = "" end From fc6caa0df05a13016b6e69c80feacec78dddf9f2 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 4 Jan 2022 19:36:21 -0500 Subject: [PATCH 8/9] Easier method of adding kwargs to bridged in aiserver.py --- aiserver.py | 66 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/aiserver.py b/aiserver.py index 1e5556a9..e902aa36 100644 --- a/aiserver.py +++ b/aiserver.py @@ -21,7 +21,7 @@ import zipfile import packaging import contextlib import traceback -from typing import Any, Union, Dict, Set, List +from typing import Any, Callable, Union, Dict, Set, List import requests import html @@ -1064,9 +1064,17 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")): def lua_log_format_name(name): return f"[{name}]" if type(name) is str else "CORE" +_bridged = {} +def bridged_kwarg(name=None): + def _bridged_kwarg(f: Callable): + _bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f + return f + return _bridged_kwarg + #==================================================================# # Event triggered when a userscript is loaded #==================================================================# +@bridged_kwarg() def load_callback(filename, modulename): print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END) @@ -1110,6 +1118,7 @@ def load_lua_scripts(): #==================================================================# # Print message that originates from the userscript with the given name #==================================================================# +@bridged_kwarg() def lua_print(msg): if(vars.lua_logname != vars.lua_koboldbridge.logging_name): vars.lua_logname = vars.lua_koboldbridge.logging_name @@ -1119,6 +1128,7 @@ def lua_print(msg): #==================================================================# # Print warning that originates from the userscript with the given name #==================================================================# +@bridged_kwarg() def lua_warn(msg): if(vars.lua_logname != vars.lua_koboldbridge.logging_name): vars.lua_logname = vars.lua_koboldbridge.logging_name @@ -1128,6 +1138,7 @@ def lua_warn(msg): #==================================================================# # Decode tokens into a string using current tokenizer #==================================================================# +@bridged_kwarg() def lua_decode(tokens): tokens = list(tokens.values()) assert type(tokens) is list @@ -1140,6 +1151,7 @@ def lua_decode(tokens): #==================================================================# # Encode string into list of token IDs using current tokenizer #==================================================================# +@bridged_kwarg() def lua_encode(string): assert type(string) is str if("tokenizer" not in globals()): @@ -1152,6 +1164,7 @@ def lua_encode(string): # Computes context given a submission, Lua array of entry UIDs and a Lua array # of folder UIDs #==================================================================# +@bridged_kwarg() def lua_compute_context(submission, entries, folders, kwargs): assert type(submission) is str if(kwargs is None): @@ -1190,6 +1203,7 @@ def lua_compute_context(submission, entries, folders, kwargs): #==================================================================# # Get property of a world info entry given its UID and property name #==================================================================# +@bridged_kwarg() def lua_get_attr(uid, k): assert type(uid) is int and type(k) is str if(uid in vars.worldinfo_u and k in ( @@ -1208,6 +1222,7 @@ def lua_get_attr(uid, k): #==================================================================# # Set property of a world info entry given its UID, property name and new value #==================================================================# +@bridged_kwarg() def lua_set_attr(uid, k, v): assert type(uid) is int and type(k) is str assert uid in vars.worldinfo_u and k in ( @@ -1227,6 +1242,7 @@ def lua_set_attr(uid, k, v): #==================================================================# # Get property of a world info folder given its UID and property name #==================================================================# +@bridged_kwarg() def lua_folder_get_attr(uid, k): assert type(uid) is int and type(k) is str if(uid in vars.wifolders_d and k in ( @@ -1237,6 +1253,7 @@ def lua_folder_get_attr(uid, k): #==================================================================# # Set property of a world info folder given its UID, property name and new value #==================================================================# +@bridged_kwarg() def lua_folder_set_attr(uid, k, v): assert type(uid) is int and type(k) is str assert uid in vars.wifolders_d and k in ( @@ -1251,12 +1268,14 @@ def lua_folder_set_attr(uid, k, v): #==================================================================# # Get the "Amount to Generate" #==================================================================# +@bridged_kwarg() def lua_get_genamt(): return vars.genamt #==================================================================# # Set the "Amount to Generate" #==================================================================# +@bridged_kwarg() def lua_set_genamt(genamt): assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0 print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END) @@ -1265,12 +1284,14 @@ def lua_set_genamt(genamt): #==================================================================# # Get the "Gens Per Action" #==================================================================# +@bridged_kwarg() def lua_get_numseqs(): return vars.numseqs #==================================================================# # Set the "Gens Per Action" #==================================================================# +@bridged_kwarg() def lua_set_numseqs(numseqs): assert type(numseqs) in (int, float) and numseqs >= 1 print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END) @@ -1279,6 +1300,7 @@ def lua_set_numseqs(numseqs): #==================================================================# # Check if a setting exists with the given name #==================================================================# +@bridged_kwarg() def lua_has_setting(setting): return setting in ( "anotedepth", @@ -1326,6 +1348,7 @@ def lua_has_setting(setting): #==================================================================# # Return the setting with the given name if it exists #==================================================================# +@bridged_kwarg() def lua_get_setting(setting): if(setting in ("settemp", "temp")): return vars.temp if(setting in ("settopp", "topp", "top_p")): return vars.top_p @@ -1350,6 +1373,7 @@ def lua_get_setting(setting): #==================================================================# # Set the setting with the given name if it exists #==================================================================# +@bridged_kwarg() def lua_set_setting(setting, v): actual_type = type(lua_get_setting(setting)) assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float)) @@ -1380,12 +1404,14 @@ def lua_set_setting(setting, v): #==================================================================# # Get contents of memory #==================================================================# +@bridged_kwarg() def lua_get_memory(): return vars.memory #==================================================================# # Set contents of memory #==================================================================# +@bridged_kwarg() def lua_set_memory(m): assert type(m) is str vars.memory = m @@ -1393,12 +1419,14 @@ def lua_set_memory(m): #==================================================================# # Get contents of author's note #==================================================================# +@bridged_kwarg() def lua_get_authorsnote(): return vars.authornote #==================================================================# # Set contents of author's note #==================================================================# +@bridged_kwarg() def lua_set_authorsnote(m): assert type(m) is str vars.authornote = m @@ -1406,12 +1434,14 @@ def lua_set_authorsnote(m): #==================================================================# # Get contents of author's note template #==================================================================# +@bridged_kwarg() def lua_get_authorsnotetemplate(): return vars.authornotetemplate #==================================================================# # Set contents of author's note template #==================================================================# +@bridged_kwarg() def lua_set_authorsnotetemplate(m): assert type(m) is str vars.authornotetemplate = m @@ -1419,6 +1449,7 @@ def lua_set_authorsnotetemplate(m): #==================================================================# # Save settings and send them to client #==================================================================# +@bridged_kwarg() def lua_resend_settings(): settingschanged() refresh_settings() @@ -1426,6 +1457,7 @@ def lua_resend_settings(): #==================================================================# # Set story chunk text and delete the chunk if the new chunk is empty #==================================================================# +@bridged_kwarg() def lua_set_chunk(k, v): assert type(k) in (int, None) and type(v) is str assert k >= 0 @@ -1458,6 +1490,7 @@ def lua_set_chunk(k, v): #==================================================================# # Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc. #==================================================================# +@bridged_kwarg() def lua_get_modeltype(): if(vars.noai): return "readonly" @@ -1486,6 +1519,7 @@ def lua_get_modeltype(): #==================================================================# # Get model backend as "transformers" or "mtj" #==================================================================# +@bridged_kwarg() def lua_get_modelbackend(): if(vars.noai): return "readonly" @@ -1498,6 +1532,7 @@ def lua_get_modelbackend(): #==================================================================# # Check whether model is loaded from a custom path #==================================================================# +@bridged_kwarg() def lua_is_custommodel(): return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ") @@ -1558,35 +1593,10 @@ bridged = { "userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"), "config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"), "lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")), - "load_callback": load_callback, - "print": lua_print, - "warn": lua_warn, - "decode": lua_decode, - "encode": lua_encode, - "get_attr": lua_get_attr, - "set_attr": lua_set_attr, - "folder_get_attr": lua_folder_get_attr, - "folder_set_attr": lua_folder_set_attr, - "get_genamt": lua_get_genamt, - "set_genamt": lua_set_genamt, - "get_memory": lua_get_memory, - "set_memory": lua_set_memory, - "get_authorsnote": lua_get_authorsnote, - "set_authorsnote": lua_set_authorsnote, - "get_authorsnote": lua_get_authorsnotetemplate, - "set_authorsnote": lua_set_authorsnotetemplate, - "compute_context": lua_compute_context, - "get_numseqs": lua_get_numseqs, - "set_numseqs": lua_set_numseqs, - "has_setting": lua_has_setting, - "get_setting": lua_get_setting, - "set_setting": lua_set_setting, - "set_chunk": lua_set_chunk, - "get_modeltype": lua_get_modeltype, - "get_modelbackend": lua_get_modelbackend, - "is_custommodel": lua_is_custommodel, "vars": vars, } +for kwarg in _bridged: + bridged[kwarg] = _bridged[kwarg] try: vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))( vars.lua_state.globals().python, From 01479c29eaf981e5585adb1e907e1d74fe0f573d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 4 Jan 2022 20:48:34 -0500 Subject: [PATCH 9/9] Fix the type hint for `bridged_kwarg` decorator --- aiserver.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aiserver.py b/aiserver.py index e902aa36..a2f3daf3 100644 --- a/aiserver.py +++ b/aiserver.py @@ -21,7 +21,7 @@ import zipfile import packaging import contextlib import traceback -from typing import Any, Callable, Union, Dict, Set, List +from typing import Any, Callable, TypeVar, Union, Dict, Set, List import requests import html @@ -1065,8 +1065,9 @@ def lua_log_format_name(name): return f"[{name}]" if type(name) is str else "CORE" _bridged = {} +F = TypeVar("F", bound=Callable) def bridged_kwarg(name=None): - def _bridged_kwarg(f: Callable): + def _bridged_kwarg(f: F): _bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f return f return _bridged_kwarg