From e289a0d3608bb504458402c4144a8321ec771f79 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sat, 11 Dec 2021 12:45:45 -0500 Subject: [PATCH] Connect bridge.lua to aiserver.py Also enables the use of input modifiers and output modifiers, but not generation modifiers. --- aiserver.py | 277 ++++++++++++++++++++++++++++++++++++++++++++++++++-- bridge.lua | 35 ++++--- 2 files changed, 292 insertions(+), 20 deletions(-) diff --git a/aiserver.py b/aiserver.py index d6bcc09f..6249bb4e 100644 --- a/aiserver.py +++ b/aiserver.py @@ -24,6 +24,8 @@ import argparse import sys import gc +import lupa + # KoboldAI import fileops import gensettings @@ -31,6 +33,11 @@ from utils import debounce import utils import structures + +if lupa.LUA_VERSION[:2] != (5, 4): + print(f"Please install lupa==1.10. You have lupa {lupa.__version__}.", file=sys.stderr) + + #==================================================================# # Variables & Storage #==================================================================# @@ -91,6 +98,10 @@ class vars: wifolders_d = {} # Dictionary of World Info folder UID-info pairs wifolders_l = [] # List of World Info folder UIDs wifolders_u = {} # Dictionary of pairs of folder UID - list of WI UID + lua_state = None # Lua state of the Lua scripting system + lua_koboldbridge = None # `koboldbridge` from bridge.lua + lua_kobold = None # `kobold` from` bridge.lua + lua_koboldcore = None # `koboldcore` from bridge.lua # badwords = [] # Array of str/chr values that should be removed from output badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], 
[26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting deletewi = -1 # Temporary storage for index to delete @@ -549,7 +560,7 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END)) if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): if(not vars.noai): print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END)) - from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM + from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM try: from transformers import GPTJModel except: @@ -751,7 +762,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme else: model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage()) vars.modeldim = get_hidden_size_from_model(model) - tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/") + tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/") # Is CUDA available? 
If so, use GPU, otherwise fall back to CPU if(vars.hascuda): if(vars.usegpu): @@ -769,7 +780,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme js = json.load(model_config) with(maybe_use_float16()): model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage()) - tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/") + tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/") vars.modeldim = get_hidden_size_from_model(model) # Is CUDA available? If so, use GPU, otherwise fall back to CPU if(vars.hascuda and vars.usegpu): @@ -780,7 +791,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # If base HuggingFace model was chosen else: # Is CUDA available? If so, use GPU, otherwise fall back to CPU - tokenizer = GPT2Tokenizer.from_pretrained(vars.model, cache_dir="cache/") + tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/") if(vars.hascuda): if(vars.usegpu): with(maybe_use_float16()): @@ -810,14 +821,18 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # vars.badwordsids.append([vocab[key]]) print("{0}OK! {1} pipeline created!{2}".format(colors.GREEN, vars.model, colors.END)) + + else: + from transformers import GPT2TokenizerFast + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") else: # If we're running Colab or OAI, we still need a tokenizer. 
if(vars.model == "Colab"): - from transformers import GPT2Tokenizer - tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B") + from transformers import GPT2TokenizerFast + tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B") elif(vars.model == "OAI"): - from transformers import GPT2Tokenizer - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + from transformers import GPT2TokenizerFast + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") # Load the TPU backend if requested elif(vars.model == "TPUMeshTransformerGPTJ"): print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) @@ -875,6 +890,239 @@ def download(): save.headers.set('Content-Disposition', 'attachment', filename='%s.json' % filename) return(save) + +#============================ LUA API =============================# + +#==================================================================# +# Event triggered when a userscript is loaded +#==================================================================# +def load_callback(filename): + print(colors.PURPLE + f"Loading Userscript [USERPLACEHOLDER] <{filename}>" + colors.END) + +#==================================================================# +# Load all Lua scripts +#==================================================================# +def load_lua_scripts(): + print(colors.PURPLE + "Loading Core Script [COREPLACEHOLDER] " + colors.END) + try: + vars.lua_koboldbridge.obliterate_multiverse() + vars.lua_koboldbridge.load_corescript("default.lua") + vars.lua_koboldbridge.load_userscripts([], [], []) + except lupa.LuaError as e: + print(e, file=sys.stderr) + exit(1) + +#==================================================================# +# Decode tokens into a string using current tokenizer +#==================================================================# +def lua_decode(tokens): + assert type(tokens) is list + return tokenizer.decode(tokens) + 
#==================================================================#
#  Encode string into list of token IDs using current tokenizer
#==================================================================#
def lua_encode(string):
    """Tokenize `string` with the module-level `tokenizer`.

    Registered in the `bridged` table as `encode`.

    Raises:
        AssertionError: if `string` is not exactly a `str`.
    """
    # Explicit raise instead of `assert`: the check must survive `python -O`.
    if type(string) is not str:
        raise AssertionError("lua_encode: `string` must be a str")
    # A huge max_length combined with truncation=True means nothing is
    # actually truncated.  NOTE(review): presumably this only exists to
    # silence the tokenizer's sequence-length warning -- confirm.
    return tokenizer.encode(string, max_length=int(4e9), truncation=True)

#==================================================================#
#  Get property of a world info entry given its UID and property name
#==================================================================#
def lua_get_attr(uid, k):
    """Return property `k` of the world info entry with UID `uid`.

    Only whitelisted property names can be read.  Returns None when the
    UID is unknown or the property is not whitelisted.

    Raises:
        AssertionError: if `uid` is not an int or `k` is not a str.
    """
    if type(uid) is not int or type(k) is not str:
        raise AssertionError("lua_get_attr: expected (int uid, str key)")
    readable = (
        "key",
        "keysecondary",
        "content",
        "comment",
        "folder",
        "num",
        "selective",
        "constant",
        "uid",
    )
    if uid in vars.worldinfo_u and k in readable:
        return vars.worldinfo_u[uid][k]
    return None

#==================================================================#
#  Set property of a world info entry given its UID, property name and new value
#==================================================================#
def lua_set_attr(uid, k, v):
    """Set property `k` of the world info entry with UID `uid` to `v`.

    Only a subset of the readable properties is writable ("folder", "num"
    and "uid" are read-only).  The new value must have the same type as
    the value it replaces; a float is converted with int() when the
    stored value is an int.

    Raises:
        AssertionError: on wrong argument types, unknown UID,
            non-writable property or mismatching value type.
    """
    # These checks guard data written by user scripts, so they are explicit
    # raises rather than `assert` statements (which vanish under -O).
    if type(uid) is not int or type(k) is not str:
        raise AssertionError("lua_set_attr: expected (int uid, str key, value)")
    writable = (
        "key",
        "keysecondary",
        "content",
        "comment",
        "selective",
        "constant",
    )
    if uid not in vars.worldinfo_u or k not in writable:
        raise AssertionError(f"lua_set_attr: cannot set {k!r} on world info entry {uid!r}")
    current = vars.worldinfo_u[uid][k]
    if type(current) is int and type(v) is float:
        v = int(v)
    if type(current) is not type(v):
        raise AssertionError(
            f"lua_set_attr: {k!r} expects {type(current).__name__}, got {type(v).__name__}"
        )
    vars.worldinfo_u[uid][k] = v

#==================================================================#
#  Get property of a world info folder given its UID and property name
#==================================================================#
def lua_folder_get_attr(uid, k):
    """Return property `k` of the world info folder with UID `uid`.

    Only "comment" can be read.  Returns None when the folder UID is
    unknown or the property is not whitelisted.

    Raises:
        AssertionError: if `uid` is not an int or `k` is not a str.
    """
    if type(uid) is not int or type(k) is not str:
        raise AssertionError("lua_folder_get_attr: expected (int uid, str key)")
    readable = (
        "comment",
    )
    if uid in vars.wifolders_d and k in readable:
        return vars.wifolders_d[uid][k]
    return None

#==================================================================#
#  Set property of a world info folder given its UID, property name and new value
#==================================================================#
def lua_folder_set_attr(uid, k, v):
    """Set property `k` of the world info folder with UID `uid` to `v`.

    Only "comment" is writable.  The new value must have the same type as
    the value it replaces; a float is converted with int() when the
    stored value is an int.

    Raises:
        AssertionError: on wrong argument types, unknown folder UID,
            non-writable property or mismatching value type.
    """
    # Explicit raises instead of `assert`: validation of script-supplied
    # values must not disappear under `python -O`.
    if type(uid) is not int or type(k) is not str:
        raise AssertionError("lua_folder_set_attr: expected (int uid, str key, value)")
    writable = (
        "comment",
    )
    if uid not in vars.wifolders_d or k not in writable:
        raise AssertionError(f"lua_folder_set_attr: cannot set {k!r} on folder {uid!r}")
    current = vars.wifolders_d[uid][k]
    if type(current) is int and type(v) is float:
        v = int(v)
    if type(current) is not type(v):
        raise AssertionError(
            f"lua_folder_set_attr: {k!r} expects {type(current).__name__}, got {type(v).__name__}"
        )
    vars.wifolders_d[uid][k] = v

#==================================================================#
#  Get the "Amount to Generate"
#==================================================================#
def lua_get_gen_len():
    """Return the current "Amount to Generate" (`vars.genamt`)."""
    return vars.genamt

#==================================================================#
#  Set the "Amount to Generate"
#==================================================================#
def lua_set_gen_len(genamt):
    """Set the "Amount to Generate" to `int(genamt)`.

    Raises:
        AssertionError: if the bridge's `userstate` is "genmod", or if
            `genamt` is not a non-negative int/float.
    """
    if vars.lua_koboldbridge.userstate == "genmod":
        raise AssertionError("lua_set_gen_len: cannot be called while userstate is 'genmod'")
    if type(genamt) not in (int, float) or not genamt >= 0:
        raise AssertionError("lua_set_gen_len: `genamt` must be a number >= 0")
    vars.genamt = int(genamt)

#==================================================================#
#  Get the "Gens Per Action"
#==================================================================#
def lua_get_numseqs():
    """Return the current "Gens Per Action" (`vars.numseqs`)."""
    return vars.numseqs

#==================================================================#
#  Set the "Gens Per Action"
#==================================================================#
def lua_set_numseqs(numseqs):
    """Set the "Gens Per Action" to `int(numseqs)`.

    Raises:
        AssertionError: if `numseqs` is not an int/float that is >= 1.
    """
    if type(numseqs) not in (int, float) or not numseqs >= 1:
        raise AssertionError("lua_set_numseqs: `numseqs` must be a number >= 1")
    # Fixed: this used to assign to vars.genamt, silently overwriting the
    # "Amount to Generate" and never changing the number of sequences.
    vars.numseqs = int(numseqs)

#==================================================================#
#  Check if a setting exists with the given name
#==================================================================#
def lua_has_setting(setting):
    """Return True if `setting` is one of the setting names known to the Lua API."""
    return setting in (
        "settemp",
        "settopp",
        "settopk",
        "settfs",
        "setreppen",
        "setoutput",
        "settknmax",
        "anotedepth",
        "setwidepth",
        "setuseprompt",
        "setadventure",
        "frmttriminc",
        "frmtrmblln",
        "frmtrmspch",
        "frmtadsnsp",
        "singleline",
    )

#==================================================================#
#  Return the setting with the given name if it exists
#==================================================================#
def lua_get_setting(setting):
    """Return the current value of the setting called `setting`.

    Setting names are the same strings accepted by `lua_has_setting`.
    Returns None for any other name.
    """
    if not isinstance(setting, str):
        return None
    # Settings stored directly as attributes of `vars`.
    attr_by_setting = {
        "settemp": "temp",
        "settopp": "top_p",
        "settopk": "top_k",
        "settfs": "tfs",
        "setreppen": "rep_pen",
        # Fixed: "setoutput" is advertised by lua_has_setting but had no
        # branch here, so it always read back as None.
        # NOTE(review): mapped to vars.genamt ("Amount to Generate") --
        # confirm this is what the "setoutput" command controls.
        "setoutput": "genamt",
        "settknmax": "max_length",
        "anotedepth": "andepth",
        "setwidepth": "widepth",
        "setuseprompt": "useprompt",
        "setadventure": "adventure",
    }
    if setting in attr_by_setting:
        return getattr(vars, attr_by_setting[setting])
    # Fixed: every formatting option used to return
    # vars.formatoptns["frmttriminc"]; each one now reads its own entry.
    # Assumes vars.formatoptns is keyed by these same names.
    if setting in ("frmttriminc", "frmtrmblln", "frmtrmspch", "frmtadsnsp", "singleline"):
        return vars.formatoptns[setting]
    return None

#==================================================================#
#  Set the setting with the given name if it exists
#==================================================================#
def lua_set_setting(setting, v):
    """Change the setting called `setting` to `v`.

    The value is converted to the type of the setting's current value and
    then forwarded to `get_message` as `{'cmd': setting, 'data': v}`.
    Turning "setadventure" on also sets `vars.actionmode` to 1.

    Raises:
        AssertionError: if `v` is None, the setting is unknown, or the
            type of `v` does not match the setting's current type (int and
            float are interchangeable).
    """
    actual_type = type(lua_get_setting(setting))
    type_ok = (
        actual_type is type(v)
        or (actual_type is int and type(v) is float)
        # Lua 5.4 hands integer literals over as int, so `settemp = 1`
        # must be accepted for a float-valued setting.
        or (actual_type is float and type(v) is int)
    )
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if v is None or not type_ok:
        raise AssertionError(
            f"lua_set_setting: cannot set {setting!r} to a value of type {type(v).__name__}"
        )
    v = actual_type(v)
    print(colors.PURPLE + f"[USERPLACEHOLDER] set {setting} to {v}" + colors.END)
    if setting == "setadventure" and v:
        vars.actionmode = 1
    get_message({'cmd': setting, 'data': v})

#==================================================================#
#  Get contents of memory
#==================================================================#
def lua_get_memory():
    """Return the contents of memory (`vars.memory`)."""
    return vars.memory

#==================================================================#
#  Set contents of memory
#==================================================================#
def lua_set_memory(m):
    """Replace the contents of memory (`vars.memory`) with the string `m`.

    Raises:
        AssertionError: if `m` is not exactly a `str`.
    """
    if type(m) is not str:
        raise AssertionError("lua_set_memory: `m` must be a str")
    vars.memory = m
#==================================================================#
#  Lua runtime startup
#==================================================================#

print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="")

# Set up Lua state
vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)

# Load bridge.lua
# `bridged` is the table of paths and Python callables handed to bridge.lua.
# bridge.lua looks these up by name, so every function it calls must be
# listed here.
bridged = {
    "corescript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts", "corescripts"),
    "userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts", "userscripts"),
    "lib_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs"),
    "load_callback": load_callback,
    "decode": lua_decode,
    "encode": lua_encode,
    "get_attr": lua_get_attr,
    "set_attr": lua_set_attr,
    # Fixed: bridge.lua calls bridged.folder_get_attr, bridged.folder_set_attr
    # and bridged.has_setting, but none of them were registered.
    "folder_get_attr": lua_folder_get_attr,
    "folder_set_attr": lua_folder_set_attr,
    "get_gen_len": lua_get_gen_len,
    "set_gen_len": lua_set_gen_len,
    "get_memory": lua_get_memory,
    "set_memory": lua_set_memory,
    "get_numseqs": lua_get_numseqs,
    "set_numseqs": lua_set_numseqs,
    "has_setting": lua_has_setting,
    "get_setting": lua_get_setting,
    "set_setting": lua_set_setting,
    "vars": vars,
}
try:
    # bridge.lua returns a function; calling it with the `python` module
    # object and the bridged table yields the three API tables.
    vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
        vars.lua_state.globals().python,
        bridged,
    )
except lupa.LuaError as e:
    print(colors.RED + "ERROR!" + colors.END)
    print(e, file=sys.stderr)
    # sys.exit rather than the site-provided exit(), which is not
    # guaranteed to exist (e.g. under `python -S`).
    sys.exit(1)
print(colors.GREEN + "OK!" + colors.END)
+ colors.END) + +# Load scripts +load_lua_scripts() + + #============================ METHODS =============================# #==================================================================# @@ -1314,6 +1562,9 @@ def actionsubmit(data, actionmode=0, force_submit=False): vars.recentedit = False vars.actionmode = actionmode + # Run the core script's input modifier + vars.lua_koboldbridge.execute_inmod() + # "Action" mode if(actionmode == 1): data = data.strip().lstrip('>') @@ -1339,6 +1590,7 @@ def actionsubmit(data, actionmode=0, force_submit=False): calcsubmit(data) # Run the first action through the generator emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) else: + vars.lua_koboldbridge.execute_outmod() refresh_story() set_aibusy(0) emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) @@ -1360,6 +1612,7 @@ def actionsubmit(data, actionmode=0, force_submit=False): calcsubmit(data) emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) else: + vars.lua_koboldbridge.execute_outmod() set_aibusy(0) emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) @@ -1700,6 +1953,8 @@ def generate(txt, minimum, maximum, found_entries=None): # Need to manually strip and decode tokens if we're not using a pipeline #already_generated = -(len(gen_in[0]) - len(tokens)) genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout] + + vars.lua_koboldbridge.execute_outmod() if(len(genout) == 1): genresult(genout[0]["generated_text"]) @@ -1800,6 +2055,8 @@ def sendtocolab(txt, min, max): else: genout = js["seqs"] + vars.lua_koboldbridge.execute_outmod() + if(len(genout) == 1): genresult(genout[0]) else: @@ -1881,6 +2138,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): genout = [{"generated_text": txt} for txt in genout] + vars.lua_koboldbridge.execute_outmod() + if(len(genout) == 1): genresult(genout[0]["generated_text"]) else: @@ -2459,6 +2718,7 @@ def 
ikrequest(txt): # Deal with the response if(req.status_code == 200): genout = req.json()["data"]["text"] + vars.lua_koboldbridge.execute_outmod() print("{0}{1}{2}".format(colors.CYAN, genout, colors.END)) vars.actions.append(genout) update_story_chunk('last') @@ -2509,6 +2769,7 @@ def oairequest(txt, min, max): # Deal with the response if(req.status_code == 200): genout = req.json()["choices"][0]["text"] + vars.lua_koboldbridge.execute_outmod() print("{0}{1}{2}".format(colors.CYAN, genout, colors.END)) vars.actions.append(genout) update_story_chunk('last') diff --git a/bridge.lua b/bridge.lua index fe1f8a4a..780e5788 100644 --- a/bridge.lua +++ b/bridge.lua @@ -252,7 +252,7 @@ return function(_python, _bridged) ---@return boolean function KoboldWorldInfoEntry:is_valid() - return bridged.worldinfo_u.get(self.uid) ~= nil + return bridged.vars.worldinfo_u.get(self.uid) ~= nil end ---@return string @@ -340,7 +340,7 @@ return function(_python, _bridged) if not check_validity(self) or type(u) ~= "number" then return end - local query = bridged.worldinfo_u.get(u) + local query = bridged.vars.worldinfo_u.get(u) if query == nil or (rawget(self, "_name") == "KoboldWorldInfoFolder" and self.uid ~= query.get("folder")) then return end @@ -373,7 +373,7 @@ return function(_python, _bridged) ---@return boolean function KoboldWorldInfoFolder:is_valid() - return bridged.wifolders_d.get(self.uid) ~= nil + return bridged.vars.wifolders_d.get(self.uid) ~= nil end ---@param t KoboldWorldInfoFolder @@ -382,7 +382,7 @@ return function(_python, _bridged) if not check_validity(t) then return 0 end - return _python.builtins.len(bridged.worldinfo_u.get(t.uid)) + return _python.builtins.len(bridged.vars.worldinfo_u.get(t.uid)) end KoboldWorldInfoFolder_mt._kobold_next = KoboldWorldInfoEntry_mt._kobold_next @@ -397,9 +397,9 @@ return function(_python, _bridged) elseif rawget(t, "_name") == "KoboldWorldInfoFolder" and k == "uid" then return rawget(t, "_uid") elseif rawget(t, "_name") == 
"KoboldWorldInfoFolder" and k == "comment" then - return bridged.wifolders_d.get(t.uid).__getitem__("comment") + return bridged.folder_get_attr(t.uid, k) elseif type(k) == "number" then - local query = rawget(t, "_name") == "KoboldWorldInfoFolder" and bridged.wifolders_u.get(t.uid) or bridged.worldinfo + local query = rawget(t, "_name") == "KoboldWorldInfoFolder" and bridged.vars.wifolders_u.get(t.uid) or bridged.vars.worldinfo k = math.tointeger(k) if k == nil or k < 1 or k > _python.builtins.len(query) then return @@ -424,7 +424,7 @@ return function(_python, _bridged) error("`"..rawget(t, "_name").."."..k.."` must be a string; you attempted to set it to a "..type(v)) return end - bridged.safe_setitem(bridged.wifolders_d.get(t.uid), "comment", v) + bridged.folder_set_attr(t.uid, k, v) return t else return rawset(t, k, v) @@ -450,7 +450,7 @@ return function(_python, _bridged) if not check_validity(self) or type(u) ~= "number" then return end - local query = bridged.wifolders_d.get(u) + local query = bridged.vars.wifolders_d.get(u) if query == nil then return end @@ -480,11 +480,11 @@ return function(_python, _bridged) ---@param t KoboldWorldInfoFolderSelector ---@return KoboldWorldInfoFolder|nil function KoboldWorldInfoFolderSelector_mt.__index(t, k) - if not check_validity(t) or type(k) ~= "number" or math.tointeger(k) == nil or k < 1 or k > _python.builtins.len(bridged.wifolders_l) then + if not check_validity(t) or type(k) ~= "number" or math.tointeger(k) == nil or k < 1 or k > _python.builtins.len(bridged.vars.wifolders_l) then return end local folder = deepcopy(KoboldWorldInfoFolder) - rawset(folder, "_uid", bridged.wifolders_l.__getitem__(k)) + rawset(folder, "_uid", bridged.vars.wifolders_l.__getitem__(k)) return folder end @@ -523,7 +523,7 @@ return function(_python, _bridged) if not check_validity(t) then return 0 end - return _python.builtins.len(bridged.worldinfo) + return _python.builtins.len(bridged.vars.worldinfo) end KoboldWorldInfo_mt._kobold_next = 
KoboldWorldInfoEntry_mt._kobold_next @@ -577,6 +577,8 @@ return function(_python, _bridged) end if k == "gen_len" then return bridged.get_gen_len() + elseif k == "numseqs" then + return bridged.get_numseqs() elseif bridged.has_setting(k) then return bridged.get_setting(k), true else @@ -588,7 +590,16 @@ return function(_python, _bridged) function KoboldSettings_mt.__newindex(t, k, v) if k == "gen_len" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then bridged.set_gen_len(v) + elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then + if koboldbridge.userstate == "genmod" then + error("Cannot set numseqs from a generation modifier") + return + end + bridged.set_numseqs(v) elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then + if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then + maybe_save_genmod_comparison_context() + end return bridged.set_setting(k, v) end return t @@ -838,7 +849,7 @@ return function(_python, _bridged) local old_package_loaded = package.loaded local old_package_searchers = package.searchers ---@param modname string - ---@param env? table + ---@param env table ---@param search_path? string ---@return any, string|nil local function requirex(modname, env, search_path)