Easier method of adding kwargs to bridged in aiserver.py
This commit is contained in:
parent
fbf5062074
commit
fc6caa0df0
66
aiserver.py
66
aiserver.py
|
@ -21,7 +21,7 @@ import zipfile
|
||||||
import packaging
|
import packaging
|
||||||
import contextlib
|
import contextlib
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any, Union, Dict, Set, List
|
from typing import Any, Callable, Union, Dict, Set, List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import html
|
import html
|
||||||
|
@ -1064,9 +1064,17 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
|
||||||
def lua_log_format_name(name):
|
def lua_log_format_name(name):
|
||||||
return f"[{name}]" if type(name) is str else "CORE"
|
return f"[{name}]" if type(name) is str else "CORE"
|
||||||
|
|
||||||
|
_bridged = {}
|
||||||
|
def bridged_kwarg(name=None):
|
||||||
|
def _bridged_kwarg(f: Callable):
|
||||||
|
_bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f
|
||||||
|
return f
|
||||||
|
return _bridged_kwarg
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Event triggered when a userscript is loaded
|
# Event triggered when a userscript is loaded
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def load_callback(filename, modulename):
|
def load_callback(filename, modulename):
|
||||||
print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END)
|
print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END)
|
||||||
|
|
||||||
|
@ -1110,6 +1118,7 @@ def load_lua_scripts():
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Print message that originates from the userscript with the given name
|
# Print message that originates from the userscript with the given name
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_print(msg):
|
def lua_print(msg):
|
||||||
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
|
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
|
||||||
vars.lua_logname = vars.lua_koboldbridge.logging_name
|
vars.lua_logname = vars.lua_koboldbridge.logging_name
|
||||||
|
@ -1119,6 +1128,7 @@ def lua_print(msg):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Print warning that originates from the userscript with the given name
|
# Print warning that originates from the userscript with the given name
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_warn(msg):
|
def lua_warn(msg):
|
||||||
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
|
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
|
||||||
vars.lua_logname = vars.lua_koboldbridge.logging_name
|
vars.lua_logname = vars.lua_koboldbridge.logging_name
|
||||||
|
@ -1128,6 +1138,7 @@ def lua_warn(msg):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Decode tokens into a string using current tokenizer
|
# Decode tokens into a string using current tokenizer
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_decode(tokens):
|
def lua_decode(tokens):
|
||||||
tokens = list(tokens.values())
|
tokens = list(tokens.values())
|
||||||
assert type(tokens) is list
|
assert type(tokens) is list
|
||||||
|
@ -1140,6 +1151,7 @@ def lua_decode(tokens):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Encode string into list of token IDs using current tokenizer
|
# Encode string into list of token IDs using current tokenizer
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_encode(string):
|
def lua_encode(string):
|
||||||
assert type(string) is str
|
assert type(string) is str
|
||||||
if("tokenizer" not in globals()):
|
if("tokenizer" not in globals()):
|
||||||
|
@ -1152,6 +1164,7 @@ def lua_encode(string):
|
||||||
# Computes context given a submission, Lua array of entry UIDs and a Lua array
|
# Computes context given a submission, Lua array of entry UIDs and a Lua array
|
||||||
# of folder UIDs
|
# of folder UIDs
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_compute_context(submission, entries, folders, kwargs):
|
def lua_compute_context(submission, entries, folders, kwargs):
|
||||||
assert type(submission) is str
|
assert type(submission) is str
|
||||||
if(kwargs is None):
|
if(kwargs is None):
|
||||||
|
@ -1190,6 +1203,7 @@ def lua_compute_context(submission, entries, folders, kwargs):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Get property of a world info entry given its UID and property name
|
# Get property of a world info entry given its UID and property name
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_get_attr(uid, k):
|
def lua_get_attr(uid, k):
|
||||||
assert type(uid) is int and type(k) is str
|
assert type(uid) is int and type(k) is str
|
||||||
if(uid in vars.worldinfo_u and k in (
|
if(uid in vars.worldinfo_u and k in (
|
||||||
|
@ -1208,6 +1222,7 @@ def lua_get_attr(uid, k):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Set property of a world info entry given its UID, property name and new value
|
# Set property of a world info entry given its UID, property name and new value
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_set_attr(uid, k, v):
|
def lua_set_attr(uid, k, v):
|
||||||
assert type(uid) is int and type(k) is str
|
assert type(uid) is int and type(k) is str
|
||||||
assert uid in vars.worldinfo_u and k in (
|
assert uid in vars.worldinfo_u and k in (
|
||||||
|
@ -1227,6 +1242,7 @@ def lua_set_attr(uid, k, v):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Get property of a world info folder given its UID and property name
|
# Get property of a world info folder given its UID and property name
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_folder_get_attr(uid, k):
|
def lua_folder_get_attr(uid, k):
|
||||||
assert type(uid) is int and type(k) is str
|
assert type(uid) is int and type(k) is str
|
||||||
if(uid in vars.wifolders_d and k in (
|
if(uid in vars.wifolders_d and k in (
|
||||||
|
@ -1237,6 +1253,7 @@ def lua_folder_get_attr(uid, k):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Set property of a world info folder given its UID, property name and new value
|
# Set property of a world info folder given its UID, property name and new value
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_folder_set_attr(uid, k, v):
|
def lua_folder_set_attr(uid, k, v):
|
||||||
assert type(uid) is int and type(k) is str
|
assert type(uid) is int and type(k) is str
|
||||||
assert uid in vars.wifolders_d and k in (
|
assert uid in vars.wifolders_d and k in (
|
||||||
|
@ -1251,12 +1268,14 @@ def lua_folder_set_attr(uid, k, v):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Get the "Amount to Generate"
|
# Get the "Amount to Generate"
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_get_genamt():
|
def lua_get_genamt():
|
||||||
return vars.genamt
|
return vars.genamt
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Set the "Amount to Generate"
|
# Set the "Amount to Generate"
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_set_genamt(genamt):
|
def lua_set_genamt(genamt):
|
||||||
assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
|
assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
|
||||||
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END)
|
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END)
|
||||||
|
@ -1265,12 +1284,14 @@ def lua_set_genamt(genamt):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Get the "Gens Per Action"
|
# Get the "Gens Per Action"
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_get_numseqs():
|
def lua_get_numseqs():
|
||||||
return vars.numseqs
|
return vars.numseqs
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Set the "Gens Per Action"
|
# Set the "Gens Per Action"
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_set_numseqs(numseqs):
|
def lua_set_numseqs(numseqs):
|
||||||
assert type(numseqs) in (int, float) and numseqs >= 1
|
assert type(numseqs) in (int, float) and numseqs >= 1
|
||||||
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END)
|
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END)
|
||||||
|
@ -1279,6 +1300,7 @@ def lua_set_numseqs(numseqs):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Check if a setting exists with the given name
|
# Check if a setting exists with the given name
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_has_setting(setting):
|
def lua_has_setting(setting):
|
||||||
return setting in (
|
return setting in (
|
||||||
"anotedepth",
|
"anotedepth",
|
||||||
|
@ -1326,6 +1348,7 @@ def lua_has_setting(setting):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Return the setting with the given name if it exists
|
# Return the setting with the given name if it exists
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_get_setting(setting):
|
def lua_get_setting(setting):
|
||||||
if(setting in ("settemp", "temp")): return vars.temp
|
if(setting in ("settemp", "temp")): return vars.temp
|
||||||
if(setting in ("settopp", "topp", "top_p")): return vars.top_p
|
if(setting in ("settopp", "topp", "top_p")): return vars.top_p
|
||||||
|
@ -1350,6 +1373,7 @@ def lua_get_setting(setting):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Set the setting with the given name if it exists
|
# Set the setting with the given name if it exists
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_set_setting(setting, v):
|
def lua_set_setting(setting, v):
|
||||||
actual_type = type(lua_get_setting(setting))
|
actual_type = type(lua_get_setting(setting))
|
||||||
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
|
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
|
||||||
|
@ -1380,12 +1404,14 @@ def lua_set_setting(setting, v):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Get contents of memory
|
# Get contents of memory
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_get_memory():
|
def lua_get_memory():
|
||||||
return vars.memory
|
return vars.memory
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Set contents of memory
|
# Set contents of memory
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_set_memory(m):
|
def lua_set_memory(m):
|
||||||
assert type(m) is str
|
assert type(m) is str
|
||||||
vars.memory = m
|
vars.memory = m
|
||||||
|
@ -1393,12 +1419,14 @@ def lua_set_memory(m):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Get contents of author's note
|
# Get contents of author's note
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_get_authorsnote():
|
def lua_get_authorsnote():
|
||||||
return vars.authornote
|
return vars.authornote
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Set contents of author's note
|
# Set contents of author's note
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_set_authorsnote(m):
|
def lua_set_authorsnote(m):
|
||||||
assert type(m) is str
|
assert type(m) is str
|
||||||
vars.authornote = m
|
vars.authornote = m
|
||||||
|
@ -1406,12 +1434,14 @@ def lua_set_authorsnote(m):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Get contents of author's note template
|
# Get contents of author's note template
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_get_authorsnotetemplate():
|
def lua_get_authorsnotetemplate():
|
||||||
return vars.authornotetemplate
|
return vars.authornotetemplate
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Set contents of author's note template
|
# Set contents of author's note template
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_set_authorsnotetemplate(m):
|
def lua_set_authorsnotetemplate(m):
|
||||||
assert type(m) is str
|
assert type(m) is str
|
||||||
vars.authornotetemplate = m
|
vars.authornotetemplate = m
|
||||||
|
@ -1419,6 +1449,7 @@ def lua_set_authorsnotetemplate(m):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Save settings and send them to client
|
# Save settings and send them to client
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_resend_settings():
|
def lua_resend_settings():
|
||||||
settingschanged()
|
settingschanged()
|
||||||
refresh_settings()
|
refresh_settings()
|
||||||
|
@ -1426,6 +1457,7 @@ def lua_resend_settings():
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Set story chunk text and delete the chunk if the new chunk is empty
|
# Set story chunk text and delete the chunk if the new chunk is empty
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_set_chunk(k, v):
|
def lua_set_chunk(k, v):
|
||||||
assert type(k) in (int, None) and type(v) is str
|
assert type(k) in (int, None) and type(v) is str
|
||||||
assert k >= 0
|
assert k >= 0
|
||||||
|
@ -1458,6 +1490,7 @@ def lua_set_chunk(k, v):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
|
# Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_get_modeltype():
|
def lua_get_modeltype():
|
||||||
if(vars.noai):
|
if(vars.noai):
|
||||||
return "readonly"
|
return "readonly"
|
||||||
|
@ -1486,6 +1519,7 @@ def lua_get_modeltype():
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Get model backend as "transformers" or "mtj"
|
# Get model backend as "transformers" or "mtj"
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_get_modelbackend():
|
def lua_get_modelbackend():
|
||||||
if(vars.noai):
|
if(vars.noai):
|
||||||
return "readonly"
|
return "readonly"
|
||||||
|
@ -1498,6 +1532,7 @@ def lua_get_modelbackend():
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Check whether model is loaded from a custom path
|
# Check whether model is loaded from a custom path
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@bridged_kwarg()
|
||||||
def lua_is_custommodel():
|
def lua_is_custommodel():
|
||||||
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
|
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
|
||||||
|
|
||||||
|
@ -1558,35 +1593,10 @@ bridged = {
|
||||||
"userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
|
"userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
|
||||||
"config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
|
"config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
|
||||||
"lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")),
|
"lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")),
|
||||||
"load_callback": load_callback,
|
|
||||||
"print": lua_print,
|
|
||||||
"warn": lua_warn,
|
|
||||||
"decode": lua_decode,
|
|
||||||
"encode": lua_encode,
|
|
||||||
"get_attr": lua_get_attr,
|
|
||||||
"set_attr": lua_set_attr,
|
|
||||||
"folder_get_attr": lua_folder_get_attr,
|
|
||||||
"folder_set_attr": lua_folder_set_attr,
|
|
||||||
"get_genamt": lua_get_genamt,
|
|
||||||
"set_genamt": lua_set_genamt,
|
|
||||||
"get_memory": lua_get_memory,
|
|
||||||
"set_memory": lua_set_memory,
|
|
||||||
"get_authorsnote": lua_get_authorsnote,
|
|
||||||
"set_authorsnote": lua_set_authorsnote,
|
|
||||||
"get_authorsnote": lua_get_authorsnotetemplate,
|
|
||||||
"set_authorsnote": lua_set_authorsnotetemplate,
|
|
||||||
"compute_context": lua_compute_context,
|
|
||||||
"get_numseqs": lua_get_numseqs,
|
|
||||||
"set_numseqs": lua_set_numseqs,
|
|
||||||
"has_setting": lua_has_setting,
|
|
||||||
"get_setting": lua_get_setting,
|
|
||||||
"set_setting": lua_set_setting,
|
|
||||||
"set_chunk": lua_set_chunk,
|
|
||||||
"get_modeltype": lua_get_modeltype,
|
|
||||||
"get_modelbackend": lua_get_modelbackend,
|
|
||||||
"is_custommodel": lua_is_custommodel,
|
|
||||||
"vars": vars,
|
"vars": vars,
|
||||||
}
|
}
|
||||||
|
for kwarg in _bridged:
|
||||||
|
bridged[kwarg] = _bridged[kwarg]
|
||||||
try:
|
try:
|
||||||
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
|
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
|
||||||
vars.lua_state.globals().python,
|
vars.lua_state.globals().python,
|
||||||
|
|
Loading…
Reference in New Issue