Easier method of adding kwargs to bridged in aiserver.py
This commit is contained in:
parent
fbf5062074
commit
fc6caa0df0
66
aiserver.py
66
aiserver.py
|
@ -21,7 +21,7 @@ import zipfile
|
|||
import packaging
|
||||
import contextlib
|
||||
import traceback
|
||||
from typing import Any, Union, Dict, Set, List
|
||||
from typing import Any, Callable, Union, Dict, Set, List
|
||||
|
||||
import requests
|
||||
import html
|
||||
|
@ -1064,9 +1064,17 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
|
|||
def lua_log_format_name(name):
|
||||
return f"[{name}]" if type(name) is str else "CORE"
|
||||
|
||||
_bridged = {}
|
||||
def bridged_kwarg(name=None):
|
||||
def _bridged_kwarg(f: Callable):
|
||||
_bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f
|
||||
return f
|
||||
return _bridged_kwarg
|
||||
|
||||
#==================================================================#
|
||||
# Event triggered when a userscript is loaded
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def load_callback(filename, modulename):
|
||||
print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END)
|
||||
|
||||
|
@ -1110,6 +1118,7 @@ def load_lua_scripts():
|
|||
#==================================================================#
|
||||
# Print message that originates from the userscript with the given name
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_print(msg):
|
||||
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
|
||||
vars.lua_logname = vars.lua_koboldbridge.logging_name
|
||||
|
@ -1119,6 +1128,7 @@ def lua_print(msg):
|
|||
#==================================================================#
|
||||
# Print warning that originates from the userscript with the given name
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_warn(msg):
|
||||
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
|
||||
vars.lua_logname = vars.lua_koboldbridge.logging_name
|
||||
|
@ -1128,6 +1138,7 @@ def lua_warn(msg):
|
|||
#==================================================================#
|
||||
# Decode tokens into a string using current tokenizer
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_decode(tokens):
|
||||
tokens = list(tokens.values())
|
||||
assert type(tokens) is list
|
||||
|
@ -1140,6 +1151,7 @@ def lua_decode(tokens):
|
|||
#==================================================================#
|
||||
# Encode string into list of token IDs using current tokenizer
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_encode(string):
|
||||
assert type(string) is str
|
||||
if("tokenizer" not in globals()):
|
||||
|
@ -1152,6 +1164,7 @@ def lua_encode(string):
|
|||
# Computes context given a submission, Lua array of entry UIDs and a Lua array
|
||||
# of folder UIDs
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_compute_context(submission, entries, folders, kwargs):
|
||||
assert type(submission) is str
|
||||
if(kwargs is None):
|
||||
|
@ -1190,6 +1203,7 @@ def lua_compute_context(submission, entries, folders, kwargs):
|
|||
#==================================================================#
|
||||
# Get property of a world info entry given its UID and property name
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_get_attr(uid, k):
|
||||
assert type(uid) is int and type(k) is str
|
||||
if(uid in vars.worldinfo_u and k in (
|
||||
|
@ -1208,6 +1222,7 @@ def lua_get_attr(uid, k):
|
|||
#==================================================================#
|
||||
# Set property of a world info entry given its UID, property name and new value
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_set_attr(uid, k, v):
|
||||
assert type(uid) is int and type(k) is str
|
||||
assert uid in vars.worldinfo_u and k in (
|
||||
|
@ -1227,6 +1242,7 @@ def lua_set_attr(uid, k, v):
|
|||
#==================================================================#
|
||||
# Get property of a world info folder given its UID and property name
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_folder_get_attr(uid, k):
|
||||
assert type(uid) is int and type(k) is str
|
||||
if(uid in vars.wifolders_d and k in (
|
||||
|
@ -1237,6 +1253,7 @@ def lua_folder_get_attr(uid, k):
|
|||
#==================================================================#
|
||||
# Set property of a world info folder given its UID, property name and new value
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_folder_set_attr(uid, k, v):
|
||||
assert type(uid) is int and type(k) is str
|
||||
assert uid in vars.wifolders_d and k in (
|
||||
|
@ -1251,12 +1268,14 @@ def lua_folder_set_attr(uid, k, v):
|
|||
#==================================================================#
|
||||
# Get the "Amount to Generate"
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_get_genamt():
|
||||
return vars.genamt
|
||||
|
||||
#==================================================================#
|
||||
# Set the "Amount to Generate"
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_set_genamt(genamt):
|
||||
assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
|
||||
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END)
|
||||
|
@ -1265,12 +1284,14 @@ def lua_set_genamt(genamt):
|
|||
#==================================================================#
|
||||
# Get the "Gens Per Action"
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_get_numseqs():
|
||||
return vars.numseqs
|
||||
|
||||
#==================================================================#
|
||||
# Set the "Gens Per Action"
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_set_numseqs(numseqs):
|
||||
assert type(numseqs) in (int, float) and numseqs >= 1
|
||||
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END)
|
||||
|
@ -1279,6 +1300,7 @@ def lua_set_numseqs(numseqs):
|
|||
#==================================================================#
|
||||
# Check if a setting exists with the given name
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_has_setting(setting):
|
||||
return setting in (
|
||||
"anotedepth",
|
||||
|
@ -1326,6 +1348,7 @@ def lua_has_setting(setting):
|
|||
#==================================================================#
|
||||
# Return the setting with the given name if it exists
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_get_setting(setting):
|
||||
if(setting in ("settemp", "temp")): return vars.temp
|
||||
if(setting in ("settopp", "topp", "top_p")): return vars.top_p
|
||||
|
@ -1350,6 +1373,7 @@ def lua_get_setting(setting):
|
|||
#==================================================================#
|
||||
# Set the setting with the given name if it exists
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_set_setting(setting, v):
|
||||
actual_type = type(lua_get_setting(setting))
|
||||
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
|
||||
|
@ -1380,12 +1404,14 @@ def lua_set_setting(setting, v):
|
|||
#==================================================================#
|
||||
# Get contents of memory
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_get_memory():
|
||||
return vars.memory
|
||||
|
||||
#==================================================================#
|
||||
# Set contents of memory
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_set_memory(m):
|
||||
assert type(m) is str
|
||||
vars.memory = m
|
||||
|
@ -1393,12 +1419,14 @@ def lua_set_memory(m):
|
|||
#==================================================================#
|
||||
# Get contents of author's note
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_get_authorsnote():
|
||||
return vars.authornote
|
||||
|
||||
#==================================================================#
|
||||
# Set contents of author's note
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_set_authorsnote(m):
|
||||
assert type(m) is str
|
||||
vars.authornote = m
|
||||
|
@ -1406,12 +1434,14 @@ def lua_set_authorsnote(m):
|
|||
#==================================================================#
|
||||
# Get contents of author's note template
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_get_authorsnotetemplate():
|
||||
return vars.authornotetemplate
|
||||
|
||||
#==================================================================#
|
||||
# Set contents of author's note template
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_set_authorsnotetemplate(m):
|
||||
assert type(m) is str
|
||||
vars.authornotetemplate = m
|
||||
|
@ -1419,6 +1449,7 @@ def lua_set_authorsnotetemplate(m):
|
|||
#==================================================================#
|
||||
# Save settings and send them to client
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_resend_settings():
|
||||
settingschanged()
|
||||
refresh_settings()
|
||||
|
@ -1426,6 +1457,7 @@ def lua_resend_settings():
|
|||
#==================================================================#
|
||||
# Set story chunk text and delete the chunk if the new chunk is empty
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_set_chunk(k, v):
|
||||
assert type(k) in (int, None) and type(v) is str
|
||||
assert k >= 0
|
||||
|
@ -1458,6 +1490,7 @@ def lua_set_chunk(k, v):
|
|||
#==================================================================#
|
||||
# Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_get_modeltype():
|
||||
if(vars.noai):
|
||||
return "readonly"
|
||||
|
@ -1486,6 +1519,7 @@ def lua_get_modeltype():
|
|||
#==================================================================#
|
||||
# Get model backend as "transformers" or "mtj"
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_get_modelbackend():
|
||||
if(vars.noai):
|
||||
return "readonly"
|
||||
|
@ -1498,6 +1532,7 @@ def lua_get_modelbackend():
|
|||
#==================================================================#
|
||||
# Check whether model is loaded from a custom path
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_is_custommodel():
|
||||
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
|
||||
|
||||
|
@ -1558,35 +1593,10 @@ bridged = {
|
|||
"userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
|
||||
"config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
|
||||
"lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")),
|
||||
"load_callback": load_callback,
|
||||
"print": lua_print,
|
||||
"warn": lua_warn,
|
||||
"decode": lua_decode,
|
||||
"encode": lua_encode,
|
||||
"get_attr": lua_get_attr,
|
||||
"set_attr": lua_set_attr,
|
||||
"folder_get_attr": lua_folder_get_attr,
|
||||
"folder_set_attr": lua_folder_set_attr,
|
||||
"get_genamt": lua_get_genamt,
|
||||
"set_genamt": lua_set_genamt,
|
||||
"get_memory": lua_get_memory,
|
||||
"set_memory": lua_set_memory,
|
||||
"get_authorsnote": lua_get_authorsnote,
|
||||
"set_authorsnote": lua_set_authorsnote,
|
||||
"get_authorsnote": lua_get_authorsnotetemplate,
|
||||
"set_authorsnote": lua_set_authorsnotetemplate,
|
||||
"compute_context": lua_compute_context,
|
||||
"get_numseqs": lua_get_numseqs,
|
||||
"set_numseqs": lua_set_numseqs,
|
||||
"has_setting": lua_has_setting,
|
||||
"get_setting": lua_get_setting,
|
||||
"set_setting": lua_set_setting,
|
||||
"set_chunk": lua_set_chunk,
|
||||
"get_modeltype": lua_get_modeltype,
|
||||
"get_modelbackend": lua_get_modelbackend,
|
||||
"is_custommodel": lua_is_custommodel,
|
||||
"vars": vars,
|
||||
}
|
||||
for kwarg in _bridged:
|
||||
bridged[kwarg] = _bridged[kwarg]
|
||||
try:
|
||||
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
|
||||
vars.lua_state.globals().python,
|
||||
|
|
Loading…
Reference in New Issue