Easier method of adding kwargs to bridged in aiserver.py

This commit is contained in:
Gnome Ann 2022-01-04 19:36:21 -05:00
parent fbf5062074
commit fc6caa0df0
1 changed files with 38 additions and 28 deletions

View File

@ -21,7 +21,7 @@ import zipfile
import packaging
import contextlib
import traceback
from typing import Any, Union, Dict, Set, List
from typing import Any, Callable, Union, Dict, Set, List
import requests
import html
@ -1064,9 +1064,17 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
def lua_log_format_name(name):
return f"[{name}]" if type(name) is str else "CORE"
_bridged = {}
def bridged_kwarg(name=None):
def _bridged_kwarg(f: Callable):
_bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f
return f
return _bridged_kwarg
#==================================================================#
# Event triggered when a userscript is loaded
#==================================================================#
@bridged_kwarg()
def load_callback(filename, modulename):
print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END)
@ -1110,6 +1118,7 @@ def load_lua_scripts():
#==================================================================#
# Print message that originates from the userscript with the given name
#==================================================================#
@bridged_kwarg()
def lua_print(msg):
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
vars.lua_logname = vars.lua_koboldbridge.logging_name
@ -1119,6 +1128,7 @@ def lua_print(msg):
#==================================================================#
# Print warning that originates from the userscript with the given name
#==================================================================#
@bridged_kwarg()
def lua_warn(msg):
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
vars.lua_logname = vars.lua_koboldbridge.logging_name
@ -1128,6 +1138,7 @@ def lua_warn(msg):
#==================================================================#
# Decode tokens into a string using current tokenizer
#==================================================================#
@bridged_kwarg()
def lua_decode(tokens):
tokens = list(tokens.values())
assert type(tokens) is list
@ -1140,6 +1151,7 @@ def lua_decode(tokens):
#==================================================================#
# Encode string into list of token IDs using current tokenizer
#==================================================================#
@bridged_kwarg()
def lua_encode(string):
assert type(string) is str
if("tokenizer" not in globals()):
@ -1152,6 +1164,7 @@ def lua_encode(string):
# Computes context given a submission, Lua array of entry UIDs and a Lua array
# of folder UIDs
#==================================================================#
@bridged_kwarg()
def lua_compute_context(submission, entries, folders, kwargs):
assert type(submission) is str
if(kwargs is None):
@ -1190,6 +1203,7 @@ def lua_compute_context(submission, entries, folders, kwargs):
#==================================================================#
# Get property of a world info entry given its UID and property name
#==================================================================#
@bridged_kwarg()
def lua_get_attr(uid, k):
assert type(uid) is int and type(k) is str
if(uid in vars.worldinfo_u and k in (
@ -1208,6 +1222,7 @@ def lua_get_attr(uid, k):
#==================================================================#
# Set property of a world info entry given its UID, property name and new value
#==================================================================#
@bridged_kwarg()
def lua_set_attr(uid, k, v):
assert type(uid) is int and type(k) is str
assert uid in vars.worldinfo_u and k in (
@ -1227,6 +1242,7 @@ def lua_set_attr(uid, k, v):
#==================================================================#
# Get property of a world info folder given its UID and property name
#==================================================================#
@bridged_kwarg()
def lua_folder_get_attr(uid, k):
assert type(uid) is int and type(k) is str
if(uid in vars.wifolders_d and k in (
@ -1237,6 +1253,7 @@ def lua_folder_get_attr(uid, k):
#==================================================================#
# Set property of a world info folder given its UID, property name and new value
#==================================================================#
@bridged_kwarg()
def lua_folder_set_attr(uid, k, v):
assert type(uid) is int and type(k) is str
assert uid in vars.wifolders_d and k in (
@ -1251,12 +1268,14 @@ def lua_folder_set_attr(uid, k, v):
#==================================================================#
# Get the "Amount to Generate"
#==================================================================#
@bridged_kwarg()
def lua_get_genamt():
return vars.genamt
#==================================================================#
# Set the "Amount to Generate"
#==================================================================#
@bridged_kwarg()
def lua_set_genamt(genamt):
assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END)
@ -1265,12 +1284,14 @@ def lua_set_genamt(genamt):
#==================================================================#
# Get the "Gens Per Action"
#==================================================================#
@bridged_kwarg()
def lua_get_numseqs():
return vars.numseqs
#==================================================================#
# Set the "Gens Per Action"
#==================================================================#
@bridged_kwarg()
def lua_set_numseqs(numseqs):
assert type(numseqs) in (int, float) and numseqs >= 1
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END)
@ -1279,6 +1300,7 @@ def lua_set_numseqs(numseqs):
#==================================================================#
# Check if a setting exists with the given name
#==================================================================#
@bridged_kwarg()
def lua_has_setting(setting):
return setting in (
"anotedepth",
@ -1326,6 +1348,7 @@ def lua_has_setting(setting):
#==================================================================#
# Return the setting with the given name if it exists
#==================================================================#
@bridged_kwarg()
def lua_get_setting(setting):
if(setting in ("settemp", "temp")): return vars.temp
if(setting in ("settopp", "topp", "top_p")): return vars.top_p
@ -1350,6 +1373,7 @@ def lua_get_setting(setting):
#==================================================================#
# Set the setting with the given name if it exists
#==================================================================#
@bridged_kwarg()
def lua_set_setting(setting, v):
actual_type = type(lua_get_setting(setting))
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
@ -1380,12 +1404,14 @@ def lua_set_setting(setting, v):
#==================================================================#
# Get contents of memory
#==================================================================#
@bridged_kwarg()
def lua_get_memory():
return vars.memory
#==================================================================#
# Set contents of memory
#==================================================================#
@bridged_kwarg()
def lua_set_memory(m):
assert type(m) is str
vars.memory = m
@ -1393,12 +1419,14 @@ def lua_set_memory(m):
#==================================================================#
# Get contents of author's note
#==================================================================#
@bridged_kwarg()
def lua_get_authorsnote():
return vars.authornote
#==================================================================#
# Set contents of author's note
#==================================================================#
@bridged_kwarg()
def lua_set_authorsnote(m):
assert type(m) is str
vars.authornote = m
@ -1406,12 +1434,14 @@ def lua_set_authorsnote(m):
#==================================================================#
# Get contents of author's note template
#==================================================================#
@bridged_kwarg()
def lua_get_authorsnotetemplate():
return vars.authornotetemplate
#==================================================================#
# Set contents of author's note template
#==================================================================#
@bridged_kwarg()
def lua_set_authorsnotetemplate(m):
assert type(m) is str
vars.authornotetemplate = m
@ -1419,6 +1449,7 @@ def lua_set_authorsnotetemplate(m):
#==================================================================#
# Save settings and send them to client
#==================================================================#
@bridged_kwarg()
def lua_resend_settings():
settingschanged()
refresh_settings()
@ -1426,6 +1457,7 @@ def lua_resend_settings():
#==================================================================#
# Set story chunk text and delete the chunk if the new chunk is empty
#==================================================================#
@bridged_kwarg()
def lua_set_chunk(k, v):
assert type(k) in (int, None) and type(v) is str
assert k >= 0
@ -1458,6 +1490,7 @@ def lua_set_chunk(k, v):
#==================================================================#
# Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
#==================================================================#
@bridged_kwarg()
def lua_get_modeltype():
if(vars.noai):
return "readonly"
@ -1486,6 +1519,7 @@ def lua_get_modeltype():
#==================================================================#
# Get model backend as "transformers" or "mtj"
#==================================================================#
@bridged_kwarg()
def lua_get_modelbackend():
if(vars.noai):
return "readonly"
@ -1498,6 +1532,7 @@ def lua_get_modelbackend():
#==================================================================#
# Check whether model is loaded from a custom path
#==================================================================#
@bridged_kwarg()
def lua_is_custommodel():
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
@ -1558,35 +1593,10 @@ bridged = {
"userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
"config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
"lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")),
"load_callback": load_callback,
"print": lua_print,
"warn": lua_warn,
"decode": lua_decode,
"encode": lua_encode,
"get_attr": lua_get_attr,
"set_attr": lua_set_attr,
"folder_get_attr": lua_folder_get_attr,
"folder_set_attr": lua_folder_set_attr,
"get_genamt": lua_get_genamt,
"set_genamt": lua_set_genamt,
"get_memory": lua_get_memory,
"set_memory": lua_set_memory,
"get_authorsnote": lua_get_authorsnote,
"set_authorsnote": lua_set_authorsnote,
"get_authorsnote": lua_get_authorsnotetemplate,
"set_authorsnote": lua_set_authorsnotetemplate,
"compute_context": lua_compute_context,
"get_numseqs": lua_get_numseqs,
"set_numseqs": lua_set_numseqs,
"has_setting": lua_has_setting,
"get_setting": lua_get_setting,
"set_setting": lua_set_setting,
"set_chunk": lua_set_chunk,
"get_modeltype": lua_get_modeltype,
"get_modelbackend": lua_get_modelbackend,
"is_custommodel": lua_is_custommodel,
"vars": vars,
}
for kwarg in _bridged:
bridged[kwarg] = _bridged[kwarg]
try:
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
vars.lua_state.globals().python,