Easier method of adding kwargs to bridged in aiserver.py

This commit is contained in:
Gnome Ann 2022-01-04 19:36:21 -05:00
parent fbf5062074
commit fc6caa0df0
1 changed files with 38 additions and 28 deletions

View File

@ -21,7 +21,7 @@ import zipfile
import packaging import packaging
import contextlib import contextlib
import traceback import traceback
from typing import Any, Union, Dict, Set, List from typing import Any, Callable, Union, Dict, Set, List
import requests import requests
import html import html
@ -1064,9 +1064,17 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
def lua_log_format_name(name): def lua_log_format_name(name):
return f"[{name}]" if type(name) is str else "CORE" return f"[{name}]" if type(name) is str else "CORE"
_bridged = {}
def bridged_kwarg(name=None):
def _bridged_kwarg(f: Callable):
_bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f
return f
return _bridged_kwarg
#==================================================================# #==================================================================#
# Event triggered when a userscript is loaded # Event triggered when a userscript is loaded
#==================================================================# #==================================================================#
@bridged_kwarg()
def load_callback(filename, modulename): def load_callback(filename, modulename):
print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END) print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END)
@ -1110,6 +1118,7 @@ def load_lua_scripts():
#==================================================================# #==================================================================#
# Print message that originates from the userscript with the given name # Print message that originates from the userscript with the given name
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_print(msg): def lua_print(msg):
if(vars.lua_logname != vars.lua_koboldbridge.logging_name): if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
vars.lua_logname = vars.lua_koboldbridge.logging_name vars.lua_logname = vars.lua_koboldbridge.logging_name
@ -1119,6 +1128,7 @@ def lua_print(msg):
#==================================================================# #==================================================================#
# Print warning that originates from the userscript with the given name # Print warning that originates from the userscript with the given name
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_warn(msg): def lua_warn(msg):
if(vars.lua_logname != vars.lua_koboldbridge.logging_name): if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
vars.lua_logname = vars.lua_koboldbridge.logging_name vars.lua_logname = vars.lua_koboldbridge.logging_name
@ -1128,6 +1138,7 @@ def lua_warn(msg):
#==================================================================# #==================================================================#
# Decode tokens into a string using current tokenizer # Decode tokens into a string using current tokenizer
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_decode(tokens): def lua_decode(tokens):
tokens = list(tokens.values()) tokens = list(tokens.values())
assert type(tokens) is list assert type(tokens) is list
@ -1140,6 +1151,7 @@ def lua_decode(tokens):
#==================================================================# #==================================================================#
# Encode string into list of token IDs using current tokenizer # Encode string into list of token IDs using current tokenizer
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_encode(string): def lua_encode(string):
assert type(string) is str assert type(string) is str
if("tokenizer" not in globals()): if("tokenizer" not in globals()):
@ -1152,6 +1164,7 @@ def lua_encode(string):
# Computes context given a submission, Lua array of entry UIDs and a Lua array # Computes context given a submission, Lua array of entry UIDs and a Lua array
# of folder UIDs # of folder UIDs
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_compute_context(submission, entries, folders, kwargs): def lua_compute_context(submission, entries, folders, kwargs):
assert type(submission) is str assert type(submission) is str
if(kwargs is None): if(kwargs is None):
@ -1190,6 +1203,7 @@ def lua_compute_context(submission, entries, folders, kwargs):
#==================================================================# #==================================================================#
# Get property of a world info entry given its UID and property name # Get property of a world info entry given its UID and property name
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_get_attr(uid, k): def lua_get_attr(uid, k):
assert type(uid) is int and type(k) is str assert type(uid) is int and type(k) is str
if(uid in vars.worldinfo_u and k in ( if(uid in vars.worldinfo_u and k in (
@ -1208,6 +1222,7 @@ def lua_get_attr(uid, k):
#==================================================================# #==================================================================#
# Set property of a world info entry given its UID, property name and new value # Set property of a world info entry given its UID, property name and new value
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_set_attr(uid, k, v): def lua_set_attr(uid, k, v):
assert type(uid) is int and type(k) is str assert type(uid) is int and type(k) is str
assert uid in vars.worldinfo_u and k in ( assert uid in vars.worldinfo_u and k in (
@ -1227,6 +1242,7 @@ def lua_set_attr(uid, k, v):
#==================================================================# #==================================================================#
# Get property of a world info folder given its UID and property name # Get property of a world info folder given its UID and property name
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_folder_get_attr(uid, k): def lua_folder_get_attr(uid, k):
assert type(uid) is int and type(k) is str assert type(uid) is int and type(k) is str
if(uid in vars.wifolders_d and k in ( if(uid in vars.wifolders_d and k in (
@ -1237,6 +1253,7 @@ def lua_folder_get_attr(uid, k):
#==================================================================# #==================================================================#
# Set property of a world info folder given its UID, property name and new value # Set property of a world info folder given its UID, property name and new value
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_folder_set_attr(uid, k, v): def lua_folder_set_attr(uid, k, v):
assert type(uid) is int and type(k) is str assert type(uid) is int and type(k) is str
assert uid in vars.wifolders_d and k in ( assert uid in vars.wifolders_d and k in (
@ -1251,12 +1268,14 @@ def lua_folder_set_attr(uid, k, v):
#==================================================================# #==================================================================#
# Get the "Amount to Generate" # Get the "Amount to Generate"
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_get_genamt(): def lua_get_genamt():
return vars.genamt return vars.genamt
#==================================================================# #==================================================================#
# Set the "Amount to Generate" # Set the "Amount to Generate"
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_set_genamt(genamt): def lua_set_genamt(genamt):
assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0 assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END) print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END)
@ -1265,12 +1284,14 @@ def lua_set_genamt(genamt):
#==================================================================# #==================================================================#
# Get the "Gens Per Action" # Get the "Gens Per Action"
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_get_numseqs(): def lua_get_numseqs():
return vars.numseqs return vars.numseqs
#==================================================================# #==================================================================#
# Set the "Gens Per Action" # Set the "Gens Per Action"
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_set_numseqs(numseqs): def lua_set_numseqs(numseqs):
assert type(numseqs) in (int, float) and numseqs >= 1 assert type(numseqs) in (int, float) and numseqs >= 1
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END) print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END)
@ -1279,6 +1300,7 @@ def lua_set_numseqs(numseqs):
#==================================================================# #==================================================================#
# Check if a setting exists with the given name # Check if a setting exists with the given name
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_has_setting(setting): def lua_has_setting(setting):
return setting in ( return setting in (
"anotedepth", "anotedepth",
@ -1326,6 +1348,7 @@ def lua_has_setting(setting):
#==================================================================# #==================================================================#
# Return the setting with the given name if it exists # Return the setting with the given name if it exists
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_get_setting(setting): def lua_get_setting(setting):
if(setting in ("settemp", "temp")): return vars.temp if(setting in ("settemp", "temp")): return vars.temp
if(setting in ("settopp", "topp", "top_p")): return vars.top_p if(setting in ("settopp", "topp", "top_p")): return vars.top_p
@ -1350,6 +1373,7 @@ def lua_get_setting(setting):
#==================================================================# #==================================================================#
# Set the setting with the given name if it exists # Set the setting with the given name if it exists
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_set_setting(setting, v): def lua_set_setting(setting, v):
actual_type = type(lua_get_setting(setting)) actual_type = type(lua_get_setting(setting))
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float)) assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
@ -1380,12 +1404,14 @@ def lua_set_setting(setting, v):
#==================================================================# #==================================================================#
# Get contents of memory # Get contents of memory
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_get_memory(): def lua_get_memory():
return vars.memory return vars.memory
#==================================================================# #==================================================================#
# Set contents of memory # Set contents of memory
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_set_memory(m): def lua_set_memory(m):
assert type(m) is str assert type(m) is str
vars.memory = m vars.memory = m
@ -1393,12 +1419,14 @@ def lua_set_memory(m):
#==================================================================# #==================================================================#
# Get contents of author's note # Get contents of author's note
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_get_authorsnote(): def lua_get_authorsnote():
return vars.authornote return vars.authornote
#==================================================================# #==================================================================#
# Set contents of author's note # Set contents of author's note
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_set_authorsnote(m): def lua_set_authorsnote(m):
assert type(m) is str assert type(m) is str
vars.authornote = m vars.authornote = m
@ -1406,12 +1434,14 @@ def lua_set_authorsnote(m):
#==================================================================# #==================================================================#
# Get contents of author's note template # Get contents of author's note template
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_get_authorsnotetemplate(): def lua_get_authorsnotetemplate():
return vars.authornotetemplate return vars.authornotetemplate
#==================================================================# #==================================================================#
# Set contents of author's note template # Set contents of author's note template
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_set_authorsnotetemplate(m): def lua_set_authorsnotetemplate(m):
assert type(m) is str assert type(m) is str
vars.authornotetemplate = m vars.authornotetemplate = m
@ -1419,6 +1449,7 @@ def lua_set_authorsnotetemplate(m):
#==================================================================# #==================================================================#
# Save settings and send them to client # Save settings and send them to client
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_resend_settings(): def lua_resend_settings():
settingschanged() settingschanged()
refresh_settings() refresh_settings()
@ -1426,6 +1457,7 @@ def lua_resend_settings():
#==================================================================# #==================================================================#
# Set story chunk text and delete the chunk if the new chunk is empty # Set story chunk text and delete the chunk if the new chunk is empty
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_set_chunk(k, v): def lua_set_chunk(k, v):
assert type(k) in (int, None) and type(v) is str assert type(k) in (int, None) and type(v) is str
assert k >= 0 assert k >= 0
@ -1458,6 +1490,7 @@ def lua_set_chunk(k, v):
#==================================================================# #==================================================================#
# Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc. # Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_get_modeltype(): def lua_get_modeltype():
if(vars.noai): if(vars.noai):
return "readonly" return "readonly"
@ -1486,6 +1519,7 @@ def lua_get_modeltype():
#==================================================================# #==================================================================#
# Get model backend as "transformers" or "mtj" # Get model backend as "transformers" or "mtj"
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_get_modelbackend(): def lua_get_modelbackend():
if(vars.noai): if(vars.noai):
return "readonly" return "readonly"
@ -1498,6 +1532,7 @@ def lua_get_modelbackend():
#==================================================================# #==================================================================#
# Check whether model is loaded from a custom path # Check whether model is loaded from a custom path
#==================================================================# #==================================================================#
@bridged_kwarg()
def lua_is_custommodel(): def lua_is_custommodel():
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ") return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
@ -1558,35 +1593,10 @@ bridged = {
"userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"), "userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
"config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"), "config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
"lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")), "lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")),
"load_callback": load_callback,
"print": lua_print,
"warn": lua_warn,
"decode": lua_decode,
"encode": lua_encode,
"get_attr": lua_get_attr,
"set_attr": lua_set_attr,
"folder_get_attr": lua_folder_get_attr,
"folder_set_attr": lua_folder_set_attr,
"get_genamt": lua_get_genamt,
"set_genamt": lua_set_genamt,
"get_memory": lua_get_memory,
"set_memory": lua_set_memory,
"get_authorsnote": lua_get_authorsnote,
"set_authorsnote": lua_set_authorsnote,
"get_authorsnote": lua_get_authorsnotetemplate,
"set_authorsnote": lua_set_authorsnotetemplate,
"compute_context": lua_compute_context,
"get_numseqs": lua_get_numseqs,
"set_numseqs": lua_set_numseqs,
"has_setting": lua_has_setting,
"get_setting": lua_get_setting,
"set_setting": lua_set_setting,
"set_chunk": lua_set_chunk,
"get_modeltype": lua_get_modeltype,
"get_modelbackend": lua_get_modelbackend,
"is_custommodel": lua_is_custommodel,
"vars": vars, "vars": vars,
} }
for kwarg in _bridged:
bridged[kwarg] = _bridged[kwarg]
try: try:
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))( vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
vars.lua_state.globals().python, vars.lua_state.globals().python,