From fc6caa0df05a13016b6e69c80feacec78dddf9f2 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 4 Jan 2022 19:36:21 -0500 Subject: [PATCH] Easier method of adding kwargs to bridged in aiserver.py --- aiserver.py | 66 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/aiserver.py b/aiserver.py index 1e5556a9..e902aa36 100644 --- a/aiserver.py +++ b/aiserver.py @@ -21,7 +21,7 @@ import zipfile import packaging import contextlib import traceback -from typing import Any, Union, Dict, Set, List +from typing import Any, Callable, Union, Dict, Set, List import requests import html @@ -1064,9 +1064,17 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")): def lua_log_format_name(name): return f"[{name}]" if type(name) is str else "CORE" +_bridged = {} +def bridged_kwarg(name=None): + def _bridged_kwarg(f: Callable): + _bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f + return f + return _bridged_kwarg + #==================================================================# # Event triggered when a userscript is loaded #==================================================================# +@bridged_kwarg() def load_callback(filename, modulename): print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END) @@ -1110,6 +1118,7 @@ def load_lua_scripts(): #==================================================================# # Print message that originates from the userscript with the given name #==================================================================# +@bridged_kwarg() def lua_print(msg): if(vars.lua_logname != vars.lua_koboldbridge.logging_name): vars.lua_logname = vars.lua_koboldbridge.logging_name @@ -1119,6 +1128,7 @@ def lua_print(msg): #==================================================================# # Print warning that originates from the userscript with the given name #==================================================================# +@bridged_kwarg() def lua_warn(msg): if(vars.lua_logname != vars.lua_koboldbridge.logging_name): vars.lua_logname = vars.lua_koboldbridge.logging_name @@ -1128,6 +1138,7 @@ def lua_warn(msg): #==================================================================# # Decode tokens into a string using current tokenizer #==================================================================# +@bridged_kwarg() def lua_decode(tokens): tokens = list(tokens.values()) assert type(tokens) is list @@ -1140,6 +1151,7 @@ def lua_decode(tokens): #==================================================================# # Encode string into list of token IDs using current tokenizer #==================================================================# +@bridged_kwarg() def lua_encode(string): assert type(string) is str if("tokenizer" not in globals()): @@ -1152,6 +1164,7 @@ def lua_encode(string): # Computes context given a submission, Lua array of entry UIDs and a Lua array # of folder UIDs #==================================================================# +@bridged_kwarg() def lua_compute_context(submission, entries, folders, kwargs): assert type(submission) is str if(kwargs is None): @@ -1190,6 +1203,7 @@ def lua_compute_context(submission, entries, folders, kwargs): #==================================================================# # Get property of a world info entry given its UID and property name #==================================================================# +@bridged_kwarg() def lua_get_attr(uid, k): assert type(uid) is int and type(k) is str if(uid in vars.worldinfo_u and k in ( @@ -1208,6 +1222,7 @@ def lua_get_attr(uid, k): #==================================================================# # Set property of a world info entry given its UID, property name and new value #==================================================================# +@bridged_kwarg() def lua_set_attr(uid, k, v): assert type(uid) is int and type(k) is str assert uid in vars.worldinfo_u and k in ( @@ -1227,6 +1242,7 @@ def lua_set_attr(uid, k, v): #==================================================================# # Get property of a world info folder given its UID and property name #==================================================================# +@bridged_kwarg() def lua_folder_get_attr(uid, k): assert type(uid) is int and type(k) is str if(uid in vars.wifolders_d and k in ( @@ -1237,6 +1253,7 @@ def lua_folder_get_attr(uid, k): #==================================================================# # Set property of a world info folder given its UID, property name and new value #==================================================================# +@bridged_kwarg() def lua_folder_set_attr(uid, k, v): assert type(uid) is int and type(k) is str assert uid in vars.wifolders_d and k in ( @@ -1251,12 +1268,14 @@ def lua_folder_set_attr(uid, k, v): #==================================================================# # Get the "Amount to Generate" #==================================================================# +@bridged_kwarg() def lua_get_genamt(): return vars.genamt #==================================================================# # Set the "Amount to Generate" #==================================================================# +@bridged_kwarg() def lua_set_genamt(genamt): assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0 print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END) @@ -1265,12 +1284,14 @@ def lua_set_genamt(genamt): #==================================================================# # Get the "Gens Per Action" #==================================================================# +@bridged_kwarg() def lua_get_numseqs(): return vars.numseqs #==================================================================# # Set the "Gens Per Action" #==================================================================# +@bridged_kwarg() def lua_set_numseqs(numseqs): assert type(numseqs) in (int, float) and numseqs >= 1 print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END) @@ -1279,6 +1300,7 @@ def lua_set_numseqs(numseqs): #==================================================================# # Check if a setting exists with the given name #==================================================================# +@bridged_kwarg() def lua_has_setting(setting): return setting in ( "anotedepth", @@ -1326,6 +1348,7 @@ def lua_has_setting(setting): #==================================================================# # Return the setting with the given name if it exists #==================================================================# +@bridged_kwarg() def lua_get_setting(setting): if(setting in ("settemp", "temp")): return vars.temp if(setting in ("settopp", "topp", "top_p")): return vars.top_p @@ -1350,6 +1373,7 @@ def lua_get_setting(setting): #==================================================================# # Set the setting with the given name if it exists #==================================================================# +@bridged_kwarg() def lua_set_setting(setting, v): actual_type = type(lua_get_setting(setting)) assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float)) @@ -1380,12 +1404,14 @@ def lua_set_setting(setting, v): #==================================================================# # Get contents of memory #==================================================================# +@bridged_kwarg() def lua_get_memory(): return vars.memory #==================================================================# # Set contents of memory #==================================================================# +@bridged_kwarg() def lua_set_memory(m): assert type(m) is str vars.memory = m @@ -1393,12 +1419,14 @@ def lua_set_memory(m): #==================================================================# # Get contents of author's note #==================================================================# +@bridged_kwarg() def lua_get_authorsnote(): return vars.authornote #==================================================================# # Set contents of author's note #==================================================================# +@bridged_kwarg() def lua_set_authorsnote(m): assert type(m) is str vars.authornote = m @@ -1406,12 +1434,14 @@ def lua_set_authorsnote(m): #==================================================================# # Get contents of author's note template #==================================================================# +@bridged_kwarg() def lua_get_authorsnotetemplate(): return vars.authornotetemplate #==================================================================# # Set contents of author's note template #==================================================================# +@bridged_kwarg() def lua_set_authorsnotetemplate(m): assert type(m) is str vars.authornotetemplate = m @@ -1419,6 +1449,7 @@ def lua_set_authorsnotetemplate(m): #==================================================================# # Save settings and send them to client #==================================================================# +@bridged_kwarg() def lua_resend_settings(): settingschanged() refresh_settings() @@ -1426,6 +1457,7 @@ def lua_resend_settings(): #==================================================================# # Set story chunk text and delete the chunk if the new chunk is empty #==================================================================# +@bridged_kwarg() def lua_set_chunk(k, v): assert type(k) in (int, None) and type(v) is str assert k >= 0 @@ -1458,6 +1490,7 @@ def lua_set_chunk(k, v): #==================================================================# # Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc. #==================================================================# +@bridged_kwarg() def lua_get_modeltype(): if(vars.noai): return "readonly" @@ -1486,6 +1519,7 @@ def lua_get_modeltype(): #==================================================================# # Get model backend as "transformers" or "mtj" #==================================================================# +@bridged_kwarg() def lua_get_modelbackend(): if(vars.noai): return "readonly" @@ -1498,6 +1532,7 @@ def lua_get_modelbackend(): #==================================================================# # Check whether model is loaded from a custom path #==================================================================# +@bridged_kwarg() def lua_is_custommodel(): return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ") @@ -1558,35 +1593,10 @@ bridged = { "userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"), "config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"), "lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")), - "load_callback": load_callback, - "print": lua_print, - "warn": lua_warn, - "decode": lua_decode, - "encode": lua_encode, - "get_attr": lua_get_attr, - "set_attr": lua_set_attr, - "folder_get_attr": lua_folder_get_attr, - "folder_set_attr": lua_folder_set_attr, - "get_genamt": lua_get_genamt, - "set_genamt": lua_set_genamt, - "get_memory": lua_get_memory, - "set_memory": lua_set_memory, - "get_authorsnote": lua_get_authorsnote, - "set_authorsnote": lua_set_authorsnote, - "get_authorsnote": lua_get_authorsnotetemplate, - "set_authorsnote": lua_set_authorsnotetemplate, - "compute_context": lua_compute_context, - "get_numseqs": lua_get_numseqs, - "set_numseqs": lua_set_numseqs, - "has_setting": lua_has_setting, - "get_setting": lua_get_setting, - "set_setting": lua_set_setting, - "set_chunk": lua_set_chunk, - "get_modeltype": lua_get_modeltype, - "get_modelbackend": lua_get_modelbackend, - "is_custommodel": lua_is_custommodel, "vars": vars, } +for kwarg in _bridged: + bridged[kwarg] = _bridged[kwarg] try: vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))( vars.lua_state.globals().python,