Add Lua API for reading model information
This commit is contained in:
parent
09df371d99
commit
d76dd35791
49
aiserver.py
49
aiserver.py
|
@ -1141,6 +1141,52 @@ def lua_set_chunk(k, v):
|
|||
print(colors.PURPLE + f"[USERPLACEHOLDER] edited story chunk {k}" + colors.END)
|
||||
inlineedit(k, v)
|
||||
|
||||
#==================================================================#
|
||||
# Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
|
||||
#==================================================================#
|
||||
def lua_get_modeltype():
|
||||
if(vars.noai):
|
||||
return "readonly"
|
||||
if(vars.model in ("Colab", "OAI", "InferKit")):
|
||||
return "api"
|
||||
if(vars.model in ("GPT2Custom", "NeoCustom")):
|
||||
hidden_size = get_hidden_size_from_model(model)
|
||||
if(vars.model in ("gpt2",) or (vars.model == "GPT2Custom" and hidden_size == 768)):
|
||||
return "gpt2"
|
||||
if(vars.model in ("gpt2-medium",) or (vars.model == "GPT2Custom" and hidden_size == 1024)):
|
||||
return "gpt2-medium"
|
||||
if(vars.model in ("gpt2-large",) or (vars.model == "GPT2Custom" and hidden_size == 1280)):
|
||||
return "gpt2-large"
|
||||
if(vars.model in ("gpt2-xl",) or (vars.model == "GPT2Custom" and hidden_size == 1600)):
|
||||
return "gpt2-xl"
|
||||
if(vars.model == "NeoCustom" and hidden_size == 768):
|
||||
return "gpt-neo-125M"
|
||||
if(vars.model in ("EleutherAI/gpt-neo-1.3M",) or (vars.model == "NeoCustom" and hidden_size == 2048)):
|
||||
return "gpt-neo-1.3M"
|
||||
if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model == "NeoCustom" and hidden_size == 2560)):
|
||||
return "gpt-neo-2.7B"
|
||||
if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "NeoCustom" and hidden_size == 4096) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096)):
|
||||
return "gpt-j-6B"
|
||||
return "unknown"
|
||||
|
||||
#==================================================================#
|
||||
# Get model backend as "transformers" or "mtj"
|
||||
#==================================================================#
|
||||
def lua_get_modelbackend():
|
||||
if(vars.noai):
|
||||
return "readonly"
|
||||
if(vars.model in ("Colab", "OAI", "InferKit")):
|
||||
return "api"
|
||||
if(vars.model in ("TPUMeshTransformerGPTJ",)):
|
||||
return "mtj"
|
||||
return "transformers"
|
||||
|
||||
#==================================================================#
|
||||
# Check whether model is loaded from a custom path
|
||||
#==================================================================#
|
||||
def lua_is_custommodel():
|
||||
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
|
||||
|
||||
#==================================================================#
|
||||
#
|
||||
#==================================================================#
|
||||
|
@ -1185,6 +1231,9 @@ bridged = {
|
|||
"set_setting": lua_set_setting,
|
||||
"resend_settings": lua_resend_settings,
|
||||
"set_chunk": lua_set_chunk,
|
||||
"get_modeltype": lua_get_modeltype,
|
||||
"get_modelbackend": lua_get_modelbackend,
|
||||
"is_custommodel": lua_is_custommodel,
|
||||
"vars": vars,
|
||||
}
|
||||
try:
|
||||
|
|
70
bridge.lua
70
bridge.lua
|
@ -140,6 +140,11 @@ return function(_python, _bridged)
|
|||
---@class KoboldLib
|
||||
---@field memory string
|
||||
---@field submission string
|
||||
---@field model string
|
||||
---@field modeltype "'readonly'"|"'api'"|"'unknown'"|"'gpt2'"|"'gpt2-medium'"|"'gpt2-large'"|"'gpt2-xl'"|"'gpt-neo-125M'"|"'gpt-neo-1.3B'"|"'gpt-neo-2.7B'"|"'gpt-j-6B'"
|
||||
---@field modelbackend "'readonly'"|"'api'"|"'transformers'"|"'mtj'"
|
||||
---@field is_custommodel boolean
|
||||
---@field custmodpth string
|
||||
local kobold = setmetatable({}, metawrapper)
|
||||
local KoboldLib_mt = setmetatable({}, metawrapper)
|
||||
local KoboldLib_getters = setmetatable({}, metawrapper)
|
||||
|
@ -838,6 +843,71 @@ return function(_python, _bridged)
|
|||
end
|
||||
|
||||
|
||||
--==========================================================================
|
||||
-- Userscript API: Model information
|
||||
--==========================================================================
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return string
|
||||
function KoboldLib_getters.modeltype(t)
|
||||
return bridged.get_modeltype()
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@param v string
|
||||
function KoboldLib_setters.modeltype(t, v)
|
||||
error("`KoboldLib.modeltype` is a read-only attribute")
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return string
|
||||
function KoboldLib_getters.model(t)
|
||||
return bridged.vars.model
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@param v string
|
||||
function KoboldLib_setters.model(t, v)
|
||||
error("`KoboldLib.model` is a read-only attribute")
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return string
|
||||
function KoboldLib_getters.modelbackend(t)
|
||||
return bridged.get_modelbackend()
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@param v string
|
||||
function KoboldLib_setters.modelbackend(t, v)
|
||||
error("`KoboldLib.modelbackend` is a read-only attribute")
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return string
|
||||
function KoboldLib_getters.is_custommodel(t)
|
||||
return bridged.is_custommodel()
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@param v string
|
||||
function KoboldLib_setters.is_custommodel(t, v)
|
||||
error("`KoboldLib.is_custommodel` is a read-only attribute")
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return string
|
||||
function KoboldLib_getters.custmodpth(t)
|
||||
return bridged.vars.custmodpth
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@param v string
|
||||
function KoboldLib_setters.custmodpth(t, v)
|
||||
error("`KoboldLib.custmodpth` is a read-only attribute")
|
||||
end
|
||||
|
||||
|
||||
--==========================================================================
|
||||
-- Userscript API: Utilities
|
||||
--==========================================================================
|
||||
|
|
Loading…
Reference in New Issue