mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Add Lua API for editing logits during generation
TPU backend not supported yet.
This commit is contained in:
76
bridge.lua
76
bridge.lua
@ -145,6 +145,7 @@ return function(_python, _bridged)
|
||||
---@field modelbackend "'readonly'"|"'api'"|"'transformers'"|"'mtj'"
|
||||
---@field is_custommodel boolean
|
||||
---@field custmodpth string
|
||||
---@field logits table<integer, table<integer, number>>
|
||||
local kobold = setmetatable({}, metawrapper)
|
||||
local KoboldLib_mt = setmetatable({}, metawrapper)
|
||||
local KoboldLib_getters = setmetatable({}, metawrapper)
|
||||
@ -200,6 +201,8 @@ return function(_python, _bridged)
|
||||
koboldbridge.resend_settings_required = false
|
||||
koboldbridge.generating = true
|
||||
koboldbridge.userstate = "inmod"
|
||||
koboldbridge.logits = {}
|
||||
koboldbridge.vocab_size = 0
|
||||
|
||||
---@return nil
|
||||
local function maybe_require_regeneration()
|
||||
@ -883,13 +886,13 @@ return function(_python, _bridged)
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return string
|
||||
---@return boolean
|
||||
function KoboldLib_getters.is_custommodel(t)
|
||||
return bridged.is_custommodel()
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@param v string
|
||||
---@param v boolean
|
||||
function KoboldLib_setters.is_custommodel(t, v)
|
||||
error("`KoboldLib.is_custommodel` is a read-only attribute")
|
||||
end
|
||||
@ -907,6 +910,65 @@ return function(_python, _bridged)
|
||||
end
|
||||
|
||||
|
||||
--==========================================================================
|
||||
-- Userscript API: Logit Warping
|
||||
--==========================================================================
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return integer
|
||||
function KoboldLib_getters.logits_rows(t)
|
||||
local backend = kobold.modelbackend
|
||||
if backend == "readonly" or backend == "api" then
|
||||
return 0
|
||||
end
|
||||
return kobold.settings.numseqs
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return integer
|
||||
function KoboldLib_setters.logits_rows(t)
|
||||
error("`KoboldLib.logits_rows` is a read-only attribute")
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return integer
|
||||
function KoboldLib_getters.logits_cols(t)
|
||||
local backend = kobold.modelbackend
|
||||
if backend == "readonly" or backend == "api" then
|
||||
return 0
|
||||
end
|
||||
return math.tointeger(koboldbridge.vocab_size)
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return integer
|
||||
function KoboldLib_setters.logits_cols(t)
|
||||
error("`KoboldLib.logits_cols` is a read-only attribute")
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return table<integer, table<integer, number>>
|
||||
function KoboldLib_getters.logits(t)
|
||||
if koboldbridge.userstate ~= "genmod" then
|
||||
return
|
||||
end
|
||||
return koboldbridge.logits
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@param v table<integer, table<integer, number>>
|
||||
function KoboldLib_setters.logits(t, v)
|
||||
if koboldbridge.userstate ~= "genmod" then
|
||||
error("Cannot write to `KoboldLib.logits` from outside of a generation modifer")
|
||||
return
|
||||
elseif type(v) ~= "table" then
|
||||
error("`KoboldLib.logits` must be a 2D list (table) of numbers; you attempted to set it to a " .. type(v))
|
||||
return
|
||||
end
|
||||
koboldbridge.logits = v
|
||||
end
|
||||
|
||||
|
||||
--==========================================================================
|
||||
-- Userscript API: Utilities
|
||||
--==========================================================================
|
||||
@ -1386,6 +1448,16 @@ return function(_python, _bridged)
|
||||
koboldbridge.userstate = "genmod"
|
||||
if koboldbridge.genmod ~= nil then
|
||||
r = koboldbridge.genmod()
|
||||
setmetatable(koboldbridge.logits, nil)
|
||||
for kr, vr in old_next, koboldbridge.logits, nil do
|
||||
setmetatable(vr, nil)
|
||||
for kc, vc in old_next, vr, nil do
|
||||
if type(vc) ~= "number" then
|
||||
error("`kobold.logits` must be a 2D table of numbers, but found a non-number element at row " .. kr .. ", column " .. kc)
|
||||
return r
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
return r
|
||||
end
|
||||
|
Reference in New Issue
Block a user