mirror of
				https://github.com/KoboldAI/KoboldAI-Client.git
				synced 2025-06-05 21:59:24 +02:00 
			
		
		
		
	Easier method of adding kwargs to bridged in aiserver.py
This commit is contained in:
		
							
								
								
									
										66
									
								
								aiserver.py
									
									
									
									
									
								
							
							
						
						
									
										66
									
								
								aiserver.py
									
									
									
									
									
								
							| @@ -21,7 +21,7 @@ import zipfile | ||||
| import packaging | ||||
| import contextlib | ||||
| import traceback | ||||
| from typing import Any, Union, Dict, Set, List | ||||
| from typing import Any, Callable, Union, Dict, Set, List | ||||
|  | ||||
| import requests | ||||
| import html | ||||
| @@ -1064,9 +1064,17 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")): | ||||
| def lua_log_format_name(name): | ||||
|     return f"[{name}]" if type(name) is str else "CORE" | ||||
|  | ||||
| _bridged = {} | ||||
| def bridged_kwarg(name=None): | ||||
|     def _bridged_kwarg(f: Callable): | ||||
|         _bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f | ||||
|         return f | ||||
|     return _bridged_kwarg | ||||
|  | ||||
| #==================================================================# | ||||
| #  Event triggered when a userscript is loaded | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def load_callback(filename, modulename): | ||||
|     print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END) | ||||
|  | ||||
| @@ -1110,6 +1118,7 @@ def load_lua_scripts(): | ||||
| #==================================================================# | ||||
| #  Print message that originates from the userscript with the given name | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_print(msg): | ||||
|     if(vars.lua_logname != vars.lua_koboldbridge.logging_name): | ||||
|         vars.lua_logname = vars.lua_koboldbridge.logging_name | ||||
| @@ -1119,6 +1128,7 @@ def lua_print(msg): | ||||
| #==================================================================# | ||||
| #  Print warning that originates from the userscript with the given name | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_warn(msg): | ||||
|     if(vars.lua_logname != vars.lua_koboldbridge.logging_name): | ||||
|         vars.lua_logname = vars.lua_koboldbridge.logging_name | ||||
| @@ -1128,6 +1138,7 @@ def lua_warn(msg): | ||||
| #==================================================================# | ||||
| #  Decode tokens into a string using current tokenizer | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_decode(tokens): | ||||
|     tokens = list(tokens.values()) | ||||
|     assert type(tokens) is list | ||||
| @@ -1140,6 +1151,7 @@ def lua_decode(tokens): | ||||
| #==================================================================# | ||||
| #  Encode string into list of token IDs using current tokenizer | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_encode(string): | ||||
|     assert type(string) is str | ||||
|     if("tokenizer" not in globals()): | ||||
| @@ -1152,6 +1164,7 @@ def lua_encode(string): | ||||
| #  Computes context given a submission, Lua array of entry UIDs and a Lua array | ||||
| #  of folder UIDs | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_compute_context(submission, entries, folders, kwargs): | ||||
|     assert type(submission) is str | ||||
|     if(kwargs is None): | ||||
| @@ -1190,6 +1203,7 @@ def lua_compute_context(submission, entries, folders, kwargs): | ||||
| #==================================================================# | ||||
| #  Get property of a world info entry given its UID and property name | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_get_attr(uid, k): | ||||
|     assert type(uid) is int and type(k) is str | ||||
|     if(uid in vars.worldinfo_u and k in ( | ||||
| @@ -1208,6 +1222,7 @@ def lua_get_attr(uid, k): | ||||
| #==================================================================# | ||||
| #  Set property of a world info entry given its UID, property name and new value | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_set_attr(uid, k, v): | ||||
|     assert type(uid) is int and type(k) is str | ||||
|     assert uid in vars.worldinfo_u and k in ( | ||||
| @@ -1227,6 +1242,7 @@ def lua_set_attr(uid, k, v): | ||||
| #==================================================================# | ||||
| #  Get property of a world info folder given its UID and property name | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_folder_get_attr(uid, k): | ||||
|     assert type(uid) is int and type(k) is str | ||||
|     if(uid in vars.wifolders_d and k in ( | ||||
| @@ -1237,6 +1253,7 @@ def lua_folder_get_attr(uid, k): | ||||
| #==================================================================# | ||||
| #  Set property of a world info folder given its UID, property name and new value | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_folder_set_attr(uid, k, v): | ||||
|     assert type(uid) is int and type(k) is str | ||||
|     assert uid in vars.wifolders_d and k in ( | ||||
| @@ -1251,12 +1268,14 @@ def lua_folder_set_attr(uid, k, v): | ||||
| #==================================================================# | ||||
| #  Get the "Amount to Generate" | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_get_genamt(): | ||||
|     return vars.genamt | ||||
|  | ||||
| #==================================================================# | ||||
| #  Set the "Amount to Generate" | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_set_genamt(genamt): | ||||
|     assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0 | ||||
|     print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END) | ||||
| @@ -1265,12 +1284,14 @@ def lua_set_genamt(genamt): | ||||
| #==================================================================# | ||||
| #  Get the "Gens Per Action" | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_get_numseqs(): | ||||
|     return vars.numseqs | ||||
|  | ||||
| #==================================================================# | ||||
| #  Set the "Gens Per Action" | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_set_numseqs(numseqs): | ||||
|     assert type(numseqs) in (int, float) and numseqs >= 1 | ||||
|     print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END) | ||||
| @@ -1279,6 +1300,7 @@ def lua_set_numseqs(numseqs): | ||||
| #==================================================================# | ||||
| #  Check if a setting exists with the given name | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_has_setting(setting): | ||||
|     return setting in ( | ||||
|         "anotedepth", | ||||
| @@ -1326,6 +1348,7 @@ def lua_has_setting(setting): | ||||
| #==================================================================# | ||||
| #  Return the setting with the given name if it exists | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_get_setting(setting): | ||||
|     if(setting in ("settemp", "temp")): return vars.temp | ||||
|     if(setting in ("settopp", "topp", "top_p")): return vars.top_p | ||||
| @@ -1350,6 +1373,7 @@ def lua_get_setting(setting): | ||||
| #==================================================================# | ||||
| #  Set the setting with the given name if it exists | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_set_setting(setting, v): | ||||
|     actual_type = type(lua_get_setting(setting)) | ||||
|     assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float)) | ||||
| @@ -1380,12 +1404,14 @@ def lua_set_setting(setting, v): | ||||
| #==================================================================# | ||||
| #  Get contents of memory | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_get_memory(): | ||||
|     return vars.memory | ||||
|  | ||||
| #==================================================================# | ||||
| #  Set contents of memory | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_set_memory(m): | ||||
|     assert type(m) is str | ||||
|     vars.memory = m | ||||
| @@ -1393,12 +1419,14 @@ def lua_set_memory(m): | ||||
| #==================================================================# | ||||
| #  Get contents of author's note | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_get_authorsnote(): | ||||
|     return vars.authornote | ||||
|  | ||||
| #==================================================================# | ||||
| #  Set contents of author's note | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_set_authorsnote(m): | ||||
|     assert type(m) is str | ||||
|     vars.authornote = m | ||||
| @@ -1406,12 +1434,14 @@ def lua_set_authorsnote(m): | ||||
| #==================================================================# | ||||
| #  Get contents of author's note template | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_get_authorsnotetemplate(): | ||||
|     return vars.authornotetemplate | ||||
|  | ||||
| #==================================================================# | ||||
| #  Set contents of author's note template | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_set_authorsnotetemplate(m): | ||||
|     assert type(m) is str | ||||
|     vars.authornotetemplate = m | ||||
| @@ -1419,6 +1449,7 @@ def lua_set_authorsnotetemplate(m): | ||||
| #==================================================================# | ||||
| #  Save settings and send them to client | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_resend_settings(): | ||||
|     settingschanged() | ||||
|     refresh_settings() | ||||
| @@ -1426,6 +1457,7 @@ def lua_resend_settings(): | ||||
| #==================================================================# | ||||
| #  Set story chunk text and delete the chunk if the new chunk is empty | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_set_chunk(k, v): | ||||
|     assert type(k) in (int, None) and type(v) is str | ||||
|     assert k >= 0 | ||||
| @@ -1458,6 +1490,7 @@ def lua_set_chunk(k, v): | ||||
| #==================================================================# | ||||
| #  Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc. | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_get_modeltype(): | ||||
|     if(vars.noai): | ||||
|         return "readonly" | ||||
| @@ -1486,6 +1519,7 @@ def lua_get_modeltype(): | ||||
| #==================================================================# | ||||
| #  Get model backend as "transformers" or "mtj" | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_get_modelbackend(): | ||||
|     if(vars.noai): | ||||
|         return "readonly" | ||||
| @@ -1498,6 +1532,7 @@ def lua_get_modelbackend(): | ||||
| #==================================================================# | ||||
| #  Check whether model is loaded from a custom path | ||||
| #==================================================================# | ||||
| @bridged_kwarg() | ||||
| def lua_is_custommodel(): | ||||
|     return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ") | ||||
|  | ||||
| @@ -1558,35 +1593,10 @@ bridged = { | ||||
|     "userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"), | ||||
|     "config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"), | ||||
|     "lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")), | ||||
|     "load_callback": load_callback, | ||||
|     "print": lua_print, | ||||
|     "warn": lua_warn, | ||||
|     "decode": lua_decode, | ||||
|     "encode": lua_encode, | ||||
|     "get_attr": lua_get_attr, | ||||
|     "set_attr": lua_set_attr, | ||||
|     "folder_get_attr": lua_folder_get_attr, | ||||
|     "folder_set_attr": lua_folder_set_attr, | ||||
|     "get_genamt": lua_get_genamt, | ||||
|     "set_genamt": lua_set_genamt, | ||||
|     "get_memory": lua_get_memory, | ||||
|     "set_memory": lua_set_memory, | ||||
|     "get_authorsnote": lua_get_authorsnote, | ||||
|     "set_authorsnote": lua_set_authorsnote, | ||||
|     "get_authorsnote": lua_get_authorsnotetemplate, | ||||
|     "set_authorsnote": lua_set_authorsnotetemplate, | ||||
|     "compute_context": lua_compute_context, | ||||
|     "get_numseqs": lua_get_numseqs, | ||||
|     "set_numseqs": lua_set_numseqs, | ||||
|     "has_setting": lua_has_setting, | ||||
|     "get_setting": lua_get_setting, | ||||
|     "set_setting": lua_set_setting, | ||||
|     "set_chunk": lua_set_chunk, | ||||
|     "get_modeltype": lua_get_modeltype, | ||||
|     "get_modelbackend": lua_get_modelbackend, | ||||
|     "is_custommodel": lua_is_custommodel, | ||||
|     "vars": vars, | ||||
| } | ||||
| for kwarg in _bridged: | ||||
|     bridged[kwarg] = _bridged[kwarg] | ||||
| try: | ||||
|     vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))( | ||||
|         vars.lua_state.globals().python, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Gnome Ann
					Gnome Ann