mirror of
				https://github.com/KoboldAI/KoboldAI-Client.git
				synced 2025-06-05 21:59:24 +02:00 
			
		
		
		
	Easier method of adding kwargs to bridged in aiserver.py
This commit is contained in:
		
							
								
								
									
										66
									
								
								aiserver.py
									
									
									
									
									
								
							
							
						
						
									
										66
									
								
								aiserver.py
									
									
									
									
									
								
							@@ -21,7 +21,7 @@ import zipfile
 | 
			
		||||
import packaging
 | 
			
		||||
import contextlib
 | 
			
		||||
import traceback
 | 
			
		||||
from typing import Any, Union, Dict, Set, List
 | 
			
		||||
from typing import Any, Callable, Union, Dict, Set, List
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import html
 | 
			
		||||
@@ -1064,9 +1064,17 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
 | 
			
		||||
def lua_log_format_name(name):
 | 
			
		||||
    return f"[{name}]" if type(name) is str else "CORE"
 | 
			
		||||
 | 
			
		||||
_bridged = {}
 | 
			
		||||
def bridged_kwarg(name=None):
 | 
			
		||||
    def _bridged_kwarg(f: Callable):
 | 
			
		||||
        _bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f
 | 
			
		||||
        return f
 | 
			
		||||
    return _bridged_kwarg
 | 
			
		||||
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Event triggered when a userscript is loaded
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def load_callback(filename, modulename):
 | 
			
		||||
    print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END)
 | 
			
		||||
 | 
			
		||||
@@ -1110,6 +1118,7 @@ def load_lua_scripts():
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Print message that originates from the userscript with the given name
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_print(msg):
 | 
			
		||||
    if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
 | 
			
		||||
        vars.lua_logname = vars.lua_koboldbridge.logging_name
 | 
			
		||||
@@ -1119,6 +1128,7 @@ def lua_print(msg):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Print warning that originates from the userscript with the given name
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_warn(msg):
 | 
			
		||||
    if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
 | 
			
		||||
        vars.lua_logname = vars.lua_koboldbridge.logging_name
 | 
			
		||||
@@ -1128,6 +1138,7 @@ def lua_warn(msg):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Decode tokens into a string using current tokenizer
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_decode(tokens):
 | 
			
		||||
    tokens = list(tokens.values())
 | 
			
		||||
    assert type(tokens) is list
 | 
			
		||||
@@ -1140,6 +1151,7 @@ def lua_decode(tokens):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Encode string into list of token IDs using current tokenizer
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_encode(string):
 | 
			
		||||
    assert type(string) is str
 | 
			
		||||
    if("tokenizer" not in globals()):
 | 
			
		||||
@@ -1152,6 +1164,7 @@ def lua_encode(string):
 | 
			
		||||
#  Computes context given a submission, Lua array of entry UIDs and a Lua array
 | 
			
		||||
#  of folder UIDs
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_compute_context(submission, entries, folders, kwargs):
 | 
			
		||||
    assert type(submission) is str
 | 
			
		||||
    if(kwargs is None):
 | 
			
		||||
@@ -1190,6 +1203,7 @@ def lua_compute_context(submission, entries, folders, kwargs):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Get property of a world info entry given its UID and property name
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_get_attr(uid, k):
 | 
			
		||||
    assert type(uid) is int and type(k) is str
 | 
			
		||||
    if(uid in vars.worldinfo_u and k in (
 | 
			
		||||
@@ -1208,6 +1222,7 @@ def lua_get_attr(uid, k):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Set property of a world info entry given its UID, property name and new value
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_set_attr(uid, k, v):
 | 
			
		||||
    assert type(uid) is int and type(k) is str
 | 
			
		||||
    assert uid in vars.worldinfo_u and k in (
 | 
			
		||||
@@ -1227,6 +1242,7 @@ def lua_set_attr(uid, k, v):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Get property of a world info folder given its UID and property name
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_folder_get_attr(uid, k):
 | 
			
		||||
    assert type(uid) is int and type(k) is str
 | 
			
		||||
    if(uid in vars.wifolders_d and k in (
 | 
			
		||||
@@ -1237,6 +1253,7 @@ def lua_folder_get_attr(uid, k):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Set property of a world info folder given its UID, property name and new value
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_folder_set_attr(uid, k, v):
 | 
			
		||||
    assert type(uid) is int and type(k) is str
 | 
			
		||||
    assert uid in vars.wifolders_d and k in (
 | 
			
		||||
@@ -1251,12 +1268,14 @@ def lua_folder_set_attr(uid, k, v):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Get the "Amount to Generate"
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_get_genamt():
 | 
			
		||||
    return vars.genamt
 | 
			
		||||
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Set the "Amount to Generate"
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_set_genamt(genamt):
 | 
			
		||||
    assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
 | 
			
		||||
    print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END)
 | 
			
		||||
@@ -1265,12 +1284,14 @@ def lua_set_genamt(genamt):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Get the "Gens Per Action"
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_get_numseqs():
 | 
			
		||||
    return vars.numseqs
 | 
			
		||||
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Set the "Gens Per Action"
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_set_numseqs(numseqs):
 | 
			
		||||
    assert type(numseqs) in (int, float) and numseqs >= 1
 | 
			
		||||
    print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END)
 | 
			
		||||
@@ -1279,6 +1300,7 @@ def lua_set_numseqs(numseqs):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Check if a setting exists with the given name
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_has_setting(setting):
 | 
			
		||||
    return setting in (
 | 
			
		||||
        "anotedepth",
 | 
			
		||||
@@ -1326,6 +1348,7 @@ def lua_has_setting(setting):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Return the setting with the given name if it exists
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_get_setting(setting):
 | 
			
		||||
    if(setting in ("settemp", "temp")): return vars.temp
 | 
			
		||||
    if(setting in ("settopp", "topp", "top_p")): return vars.top_p
 | 
			
		||||
@@ -1350,6 +1373,7 @@ def lua_get_setting(setting):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Set the setting with the given name if it exists
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_set_setting(setting, v):
 | 
			
		||||
    actual_type = type(lua_get_setting(setting))
 | 
			
		||||
    assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
 | 
			
		||||
@@ -1380,12 +1404,14 @@ def lua_set_setting(setting, v):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Get contents of memory
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_get_memory():
 | 
			
		||||
    return vars.memory
 | 
			
		||||
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Set contents of memory
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_set_memory(m):
 | 
			
		||||
    assert type(m) is str
 | 
			
		||||
    vars.memory = m
 | 
			
		||||
@@ -1393,12 +1419,14 @@ def lua_set_memory(m):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Get contents of author's note
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_get_authorsnote():
 | 
			
		||||
    return vars.authornote
 | 
			
		||||
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Set contents of author's note
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_set_authorsnote(m):
 | 
			
		||||
    assert type(m) is str
 | 
			
		||||
    vars.authornote = m
 | 
			
		||||
@@ -1406,12 +1434,14 @@ def lua_set_authorsnote(m):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Get contents of author's note template
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_get_authorsnotetemplate():
 | 
			
		||||
    return vars.authornotetemplate
 | 
			
		||||
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Set contents of author's note template
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_set_authorsnotetemplate(m):
 | 
			
		||||
    assert type(m) is str
 | 
			
		||||
    vars.authornotetemplate = m
 | 
			
		||||
@@ -1419,6 +1449,7 @@ def lua_set_authorsnotetemplate(m):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Save settings and send them to client
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_resend_settings():
 | 
			
		||||
    settingschanged()
 | 
			
		||||
    refresh_settings()
 | 
			
		||||
@@ -1426,6 +1457,7 @@ def lua_resend_settings():
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Set story chunk text and delete the chunk if the new chunk is empty
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_set_chunk(k, v):
 | 
			
		||||
    assert type(k) in (int, None) and type(v) is str
 | 
			
		||||
    assert k >= 0
 | 
			
		||||
@@ -1458,6 +1490,7 @@ def lua_set_chunk(k, v):
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_get_modeltype():
 | 
			
		||||
    if(vars.noai):
 | 
			
		||||
        return "readonly"
 | 
			
		||||
@@ -1486,6 +1519,7 @@ def lua_get_modeltype():
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Get model backend as "transformers" or "mtj"
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_get_modelbackend():
 | 
			
		||||
    if(vars.noai):
 | 
			
		||||
        return "readonly"
 | 
			
		||||
@@ -1498,6 +1532,7 @@ def lua_get_modelbackend():
 | 
			
		||||
#==================================================================#
 | 
			
		||||
#  Check whether model is loaded from a custom path
 | 
			
		||||
#==================================================================#
 | 
			
		||||
@bridged_kwarg()
 | 
			
		||||
def lua_is_custommodel():
 | 
			
		||||
    return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
 | 
			
		||||
 | 
			
		||||
@@ -1558,35 +1593,10 @@ bridged = {
 | 
			
		||||
    "userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
 | 
			
		||||
    "config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
 | 
			
		||||
    "lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")),
 | 
			
		||||
    "load_callback": load_callback,
 | 
			
		||||
    "print": lua_print,
 | 
			
		||||
    "warn": lua_warn,
 | 
			
		||||
    "decode": lua_decode,
 | 
			
		||||
    "encode": lua_encode,
 | 
			
		||||
    "get_attr": lua_get_attr,
 | 
			
		||||
    "set_attr": lua_set_attr,
 | 
			
		||||
    "folder_get_attr": lua_folder_get_attr,
 | 
			
		||||
    "folder_set_attr": lua_folder_set_attr,
 | 
			
		||||
    "get_genamt": lua_get_genamt,
 | 
			
		||||
    "set_genamt": lua_set_genamt,
 | 
			
		||||
    "get_memory": lua_get_memory,
 | 
			
		||||
    "set_memory": lua_set_memory,
 | 
			
		||||
    "get_authorsnote": lua_get_authorsnote,
 | 
			
		||||
    "set_authorsnote": lua_set_authorsnote,
 | 
			
		||||
    "get_authorsnote": lua_get_authorsnotetemplate,
 | 
			
		||||
    "set_authorsnote": lua_set_authorsnotetemplate,
 | 
			
		||||
    "compute_context": lua_compute_context,
 | 
			
		||||
    "get_numseqs": lua_get_numseqs,
 | 
			
		||||
    "set_numseqs": lua_set_numseqs,
 | 
			
		||||
    "has_setting": lua_has_setting,
 | 
			
		||||
    "get_setting": lua_get_setting,
 | 
			
		||||
    "set_setting": lua_set_setting,
 | 
			
		||||
    "set_chunk": lua_set_chunk,
 | 
			
		||||
    "get_modeltype": lua_get_modeltype,
 | 
			
		||||
    "get_modelbackend": lua_get_modelbackend,
 | 
			
		||||
    "is_custommodel": lua_is_custommodel,
 | 
			
		||||
    "vars": vars,
 | 
			
		||||
}
 | 
			
		||||
for kwarg in _bridged:
 | 
			
		||||
    bridged[kwarg] = _bridged[kwarg]
 | 
			
		||||
try:
 | 
			
		||||
    vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
 | 
			
		||||
        vars.lua_state.globals().python,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user