Add Lua API for editing logits during generation

TPU backend not supported yet.
This commit is contained in:
Gnome Ann
2021-12-12 16:18:45 -05:00
parent e2c3ac041b
commit ceabd2ef7b
2 changed files with 110 additions and 11 deletions

View File

@ -145,6 +145,7 @@ return function(_python, _bridged)
---@field modelbackend "'readonly'"|"'api'"|"'transformers'"|"'mtj'"
---@field is_custommodel boolean
---@field custmodpth string
---@field logits table<integer, table<integer, number>>
local kobold = setmetatable({}, metawrapper)
local KoboldLib_mt = setmetatable({}, metawrapper)
local KoboldLib_getters = setmetatable({}, metawrapper)
@ -200,6 +201,8 @@ return function(_python, _bridged)
koboldbridge.resend_settings_required = false
koboldbridge.generating = true
koboldbridge.userstate = "inmod"
koboldbridge.logits = {}
koboldbridge.vocab_size = 0
---@return nil
local function maybe_require_regeneration()
@ -883,13 +886,13 @@ return function(_python, _bridged)
end
---@param t KoboldLib
---@return string
---@return boolean
function KoboldLib_getters.is_custommodel(t)
return bridged.is_custommodel()
end
---@param t KoboldLib
---@param v string
---@param v boolean
function KoboldLib_setters.is_custommodel(t, v)
error("`KoboldLib.is_custommodel` is a read-only attribute")
end
@ -907,6 +910,65 @@ return function(_python, _bridged)
end
--==========================================================================
-- Userscript API: Logit Warping
--==========================================================================
---@param t KoboldLib
---@return integer
function KoboldLib_getters.logits_rows(t)
local backend = kobold.modelbackend
if backend == "readonly" or backend == "api" then
return 0
end
return kobold.settings.numseqs
end
---@param t KoboldLib
---@return integer
function KoboldLib_setters.logits_rows(t)
error("`KoboldLib.logits_rows` is a read-only attribute")
end
---@param t KoboldLib
---@return integer
function KoboldLib_getters.logits_cols(t)
local backend = kobold.modelbackend
if backend == "readonly" or backend == "api" then
return 0
end
return math.tointeger(koboldbridge.vocab_size)
end
---@param t KoboldLib
---@return integer
function KoboldLib_setters.logits_cols(t)
error("`KoboldLib.logits_cols` is a read-only attribute")
end
---@param t KoboldLib
---@return table<integer, table<integer, number>>
function KoboldLib_getters.logits(t)
if koboldbridge.userstate ~= "genmod" then
return
end
return koboldbridge.logits
end
---@param t KoboldLib
---@param v table<integer, table<integer, number>>
function KoboldLib_setters.logits(t, v)
if koboldbridge.userstate ~= "genmod" then
error("Cannot write to `KoboldLib.logits` from outside of a generation modifer")
return
elseif type(v) ~= "table" then
error("`KoboldLib.logits` must be a 2D list (table) of numbers; you attempted to set it to a " .. type(v))
return
end
koboldbridge.logits = v
end
--==========================================================================
-- Userscript API: Utilities
--==========================================================================
@ -1386,6 +1448,16 @@ return function(_python, _bridged)
koboldbridge.userstate = "genmod"
if koboldbridge.genmod ~= nil then
r = koboldbridge.genmod()
setmetatable(koboldbridge.logits, nil)
for kr, vr in old_next, koboldbridge.logits, nil do
setmetatable(vr, nil)
for kc, vc in old_next, vr, nil do
if type(vc) ~= "number" then
error("`kobold.logits` must be a 2D table of numbers, but found a non-number element at row " .. kr .. ", column " .. kc)
return r
end
end
end
end
return r
end