Add Lua API for reading model information

This commit is contained in:
Gnome Ann 2021-12-12 12:09:59 -05:00
parent 09df371d99
commit d76dd35791
2 changed files with 119 additions and 0 deletions

View File

@ -1141,6 +1141,52 @@ def lua_set_chunk(k, v):
print(colors.PURPLE + f"[USERPLACEHOLDER] edited story chunk {k}" + colors.END)
inlineedit(k, v)
#==================================================================#
#  Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
#==================================================================#
def lua_get_modeltype():
    # Return a canonical model-type string for Lua userscripts:
    # "readonly" when AI is disabled, "api" for remote backends,
    # a recognized model identifier, or "unknown" as a fallback.
    if(vars.noai):
        return "readonly"
    if(vars.model in ("Colab", "OAI", "InferKit")):
        return "api"
    # For custom local models the checkpoint name is arbitrary, so infer
    # the architecture from the hidden dimension.  hidden_size is only
    # read below behind the same model-name guard, so it is always
    # defined when referenced.
    if(vars.model in ("GPT2Custom", "NeoCustom")):
        hidden_size = get_hidden_size_from_model(model)
    if(vars.model in ("gpt2",) or (vars.model == "GPT2Custom" and hidden_size == 768)):
        return "gpt2"
    if(vars.model in ("gpt2-medium",) or (vars.model == "GPT2Custom" and hidden_size == 1024)):
        return "gpt2-medium"
    if(vars.model in ("gpt2-large",) or (vars.model == "GPT2Custom" and hidden_size == 1280)):
        return "gpt2-large"
    if(vars.model in ("gpt2-xl",) or (vars.model == "GPT2Custom" and hidden_size == 1600)):
        return "gpt2-xl"
    if(vars.model == "NeoCustom" and hidden_size == 768):
        return "gpt-neo-125M"
    # Fixed: the 1.3-billion-parameter model is "gpt-neo-1.3B", not
    # "gpt-neo-1.3M" — this matches the EleutherAI checkpoint name and
    # the KoboldLib.modeltype annotation in bridge.lua.
    if(vars.model in ("EleutherAI/gpt-neo-1.3B",) or (vars.model == "NeoCustom" and hidden_size == 2048)):
        return "gpt-neo-1.3B"
    if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model == "NeoCustom" and hidden_size == 2560)):
        return "gpt-neo-2.7B"
    if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "NeoCustom" and hidden_size == 4096) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096)):
        return "gpt-j-6B"
    return "unknown"
#==================================================================#
#  Get model backend as "transformers" or "mtj"
#==================================================================#
def lua_get_modelbackend():
    # Tell Lua userscripts which inference backend serves the model:
    # "readonly" (AI disabled), "api" (remote service), "mtj" (TPU
    # Mesh Transformer JAX) or "transformers" (everything else).
    if(vars.noai):
        return "readonly"
    model_name = vars.model
    if(model_name in ("Colab", "OAI", "InferKit")):
        return "api"
    return "mtj" if model_name == "TPUMeshTransformerGPTJ" else "transformers"
#==================================================================#
#  Check whether model is loaded from a custom path
#==================================================================#
def lua_is_custommodel():
    # True when the active model was loaded from a user-supplied path
    # rather than a stock Hugging Face checkpoint or a remote API.
    custom_model_ids = ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
    return vars.model in custom_model_ids
#==================================================================#
#
#==================================================================#
@ -1185,6 +1231,9 @@ bridged = {
"set_setting": lua_set_setting,
"resend_settings": lua_resend_settings,
"set_chunk": lua_set_chunk,
"get_modeltype": lua_get_modeltype,
"get_modelbackend": lua_get_modelbackend,
"is_custommodel": lua_is_custommodel,
"vars": vars,
}
try:

View File

@ -140,6 +140,11 @@ return function(_python, _bridged)
---@class KoboldLib
---@field memory string
---@field submission string
---@field model string
---@field modeltype "'readonly'"|"'api'"|"'unknown'"|"'gpt2'"|"'gpt2-medium'"|"'gpt2-large'"|"'gpt2-xl'"|"'gpt-neo-125M'"|"'gpt-neo-1.3B'"|"'gpt-neo-2.7B'"|"'gpt-j-6B'"
---@field modelbackend "'readonly'"|"'api'"|"'transformers'"|"'mtj'"
---@field is_custommodel boolean
---@field custmodpth string
local kobold = setmetatable({}, metawrapper)
local KoboldLib_mt = setmetatable({}, metawrapper)
local KoboldLib_getters = setmetatable({}, metawrapper)
@ -838,6 +843,71 @@ return function(_python, _bridged)
end
--==========================================================================
-- Userscript API: Model information
--==========================================================================

---Getter for `KoboldLib.modeltype`: the canonical model-type string
---(e.g. "gpt2-xl", "gpt-neo-2.7B"), as reported by the Python bridge.
---@param t KoboldLib
---@return string
function KoboldLib_getters.modeltype(t)
    local modeltype = bridged.get_modeltype()
    return modeltype
end
---Setter for `KoboldLib.modeltype`; the field is read-only, so any
---assignment raises an error.
---@param t KoboldLib
---@param v string
function KoboldLib_setters.modeltype(t, v)
    local msg = "`KoboldLib.modeltype` is a read-only attribute"
    error(msg)
end
---Getter for `KoboldLib.model`: the raw model identifier from the
---server's `vars.model` (e.g. "NeoCustom", "EleutherAI/gpt-neo-2.7B").
---@param t KoboldLib
---@return string
function KoboldLib_getters.model(t)
    local model_name = bridged.vars.model
    return model_name
end
---Setter for `KoboldLib.model`; the field is read-only, so any
---assignment raises an error.
---@param t KoboldLib
---@param v string
function KoboldLib_setters.model(t, v)
    local msg = "`KoboldLib.model` is a read-only attribute"
    error(msg)
end
---Getter for `KoboldLib.modelbackend`: which inference backend serves
---the current model ("transformers", "mtj", "api" or "readonly").
---@param t KoboldLib
---@return string
function KoboldLib_getters.modelbackend(t)
    local backend = bridged.get_modelbackend()
    return backend
end
---Setter for `KoboldLib.modelbackend`; the field is read-only, so any
---assignment raises an error.
---@param t KoboldLib
---@param v string
function KoboldLib_setters.modelbackend(t, v)
    local msg = "`KoboldLib.modelbackend` is a read-only attribute"
    error(msg)
end
---Getter for `KoboldLib.is_custommodel`: whether the model was loaded
---from a custom path (the Python bridge returns a boolean; the
---previous `---@return string` annotation was incorrect and disagreed
---with the `---@field is_custommodel boolean` class annotation).
---@param t KoboldLib
---@return boolean
function KoboldLib_getters.is_custommodel(t)
return bridged.is_custommodel()
end
---Setter for `KoboldLib.is_custommodel`; the field is read-only, so any
---assignment raises an error.  (Param annotation corrected from string
---to boolean to match the declared field type.)
---@param t KoboldLib
---@param v boolean
function KoboldLib_setters.is_custommodel(t, v)
error("`KoboldLib.is_custommodel` is a read-only attribute")
end
---Getter for `KoboldLib.custmodpth`: the filesystem path the custom
---model was loaded from, as stored in the server's `vars.custmodpth`.
---@param t KoboldLib
---@return string
function KoboldLib_getters.custmodpth(t)
    local path = bridged.vars.custmodpth
    return path
end
---Setter for `KoboldLib.custmodpth`; the field is read-only, so any
---assignment raises an error.
---@param t KoboldLib
---@param v string
function KoboldLib_setters.custmodpth(t, v)
    local msg = "`KoboldLib.custmodpth` is a read-only attribute"
    error(msg)
end
--==========================================================================
-- Userscript API: Utilities
--==========================================================================