Add Lua API for reading model information

This commit is contained in:
Gnome Ann
2021-12-12 12:09:59 -05:00
parent 09df371d99
commit d76dd35791
2 changed files with 119 additions and 0 deletions

View File

@@ -1141,6 +1141,52 @@ def lua_set_chunk(k, v):
print(colors.PURPLE + f"[USERPLACEHOLDER] edited story chunk {k}" + colors.END)
inlineedit(k, v)
#==================================================================#
# Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
#==================================================================#
def lua_get_modeltype():
if(vars.noai):
return "readonly"
if(vars.model in ("Colab", "OAI", "InferKit")):
return "api"
if(vars.model in ("GPT2Custom", "NeoCustom")):
hidden_size = get_hidden_size_from_model(model)
if(vars.model in ("gpt2",) or (vars.model == "GPT2Custom" and hidden_size == 768)):
return "gpt2"
if(vars.model in ("gpt2-medium",) or (vars.model == "GPT2Custom" and hidden_size == 1024)):
return "gpt2-medium"
if(vars.model in ("gpt2-large",) or (vars.model == "GPT2Custom" and hidden_size == 1280)):
return "gpt2-large"
if(vars.model in ("gpt2-xl",) or (vars.model == "GPT2Custom" and hidden_size == 1600)):
return "gpt2-xl"
if(vars.model == "NeoCustom" and hidden_size == 768):
return "gpt-neo-125M"
if(vars.model in ("EleutherAI/gpt-neo-1.3M",) or (vars.model == "NeoCustom" and hidden_size == 2048)):
return "gpt-neo-1.3M"
if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model == "NeoCustom" and hidden_size == 2560)):
return "gpt-neo-2.7B"
if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "NeoCustom" and hidden_size == 4096) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096)):
return "gpt-j-6B"
return "unknown"
#==================================================================#
# Get model backend as "transformers" or "mtj"
#==================================================================#
def lua_get_modelbackend():
if(vars.noai):
return "readonly"
if(vars.model in ("Colab", "OAI", "InferKit")):
return "api"
if(vars.model in ("TPUMeshTransformerGPTJ",)):
return "mtj"
return "transformers"
#==================================================================#
# Check whether model is loaded from a custom path
#==================================================================#
def lua_is_custommodel():
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
#==================================================================#
#
#==================================================================#
@@ -1185,6 +1231,9 @@ bridged = {
"set_setting": lua_set_setting,
"resend_settings": lua_resend_settings,
"set_chunk": lua_set_chunk,
"get_modeltype": lua_get_modeltype,
"get_modelbackend": lua_get_modelbackend,
"is_custommodel": lua_is_custommodel,
"vars": vars,
}
try: