Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Add Lua API for reading model information
aiserver.py: 49 additions, 0 deletions
@@ -1141,6 +1141,52 @@ def lua_set_chunk(k, v):
         print(colors.PURPLE + f"[USERPLACEHOLDER] edited story chunk {k}" + colors.END)
         inlineedit(k, v)
 
+#==================================================================#
+# Get model type as "gpt2-xl", "gpt-neo-2.7B", etc.
+#==================================================================#
+def lua_get_modeltype():
+    if(vars.noai):
+        return "readonly"
+    if(vars.model in ("Colab", "OAI", "InferKit")):
+        return "api"
+    if(vars.model in ("GPT2Custom", "NeoCustom")):
+        hidden_size = get_hidden_size_from_model(model)
+    if(vars.model in ("gpt2",) or (vars.model == "GPT2Custom" and hidden_size == 768)):
+        return "gpt2"
+    if(vars.model in ("gpt2-medium",) or (vars.model == "GPT2Custom" and hidden_size == 1024)):
+        return "gpt2-medium"
+    if(vars.model in ("gpt2-large",) or (vars.model == "GPT2Custom" and hidden_size == 1280)):
+        return "gpt2-large"
+    if(vars.model in ("gpt2-xl",) or (vars.model == "GPT2Custom" and hidden_size == 1600)):
+        return "gpt2-xl"
+    if(vars.model == "NeoCustom" and hidden_size == 768):
+        return "gpt-neo-125M"
if(vars.model in ("EleutherAI/gpt-neo-1.3M",) or (vars.model == "NeoCustom" and hidden_size == 2048)):
|
||||
return "gpt-neo-1.3M"
+    if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model == "NeoCustom" and hidden_size == 2560)):
+        return "gpt-neo-2.7B"
+    if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "NeoCustom" and hidden_size == 4096) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096)):
+        return "gpt-j-6B"
+    return "unknown"
+
+#==================================================================#
+# Get model backend as "transformers" or "mtj"
+#==================================================================#
+def lua_get_modelbackend():
+    if(vars.noai):
+        return "readonly"
+    if(vars.model in ("Colab", "OAI", "InferKit")):
+        return "api"
+    if(vars.model in ("TPUMeshTransformerGPTJ",)):
+        return "mtj"
+    return "transformers"
+
+#==================================================================#
+# Check whether model is loaded from a custom path
+#==================================================================#
+def lua_is_custommodel():
+    return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
+
 #==================================================================#
 # 
 #==================================================================#
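The type check above calls get_hidden_size_from_model, an existing helper defined elsewhere in aiserver.py that this diff does not show. A minimal sketch of the idea, assuming a Hugging Face transformers model whose config exposes the embedding width (an illustration, not the repository's actual helper):

    def get_hidden_size_from_model(model):
        # Sketch only: read the embedding width off the model config.
        # GPT-Neo configs expose `hidden_size` directly, and GPT-2
        # configs alias their `n_embd` attribute to the same name.
        return int(model.config.hidden_size)

That width is what disambiguates custom checkpoints: 768 means gpt2 under GPT2Custom but gpt-neo-125M under NeoCustom, which is why each branch tests vars.model together with hidden_size.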
@@ -1185,6 +1231,9 @@ bridged = {
     "set_setting": lua_set_setting,
     "resend_settings": lua_resend_settings,
     "set_chunk": lua_set_chunk,
+    "get_modeltype": lua_get_modeltype,
+    "get_modelbackend": lua_get_modelbackend,
+    "is_custommodel": lua_is_custommodel,
     "vars": vars,
 }
 try:
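The second hunk registers the three new functions in the bridged table that aiserver.py exposes to the Lua side, so userscripts can query the loaded model through the bridge. A quick Python-side sanity check, using only names visible in this diff (the printed values depend on which model is loaded):

    for name in ("get_modeltype", "get_modelbackend", "is_custommodel"):
        print(name, "->", bridged[name]())

All three return plain strings or booleans rather than Python objects, so the values marshal cleanly across the Python-Lua boundary.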