Add Lua API for editing logits during generation
TPU backend not supported yet.
This commit is contained in:
parent
e2c3ac041b
commit
ceabd2ef7b
45
aiserver.py
45
aiserver.py
|
@ -104,6 +104,7 @@ class vars:
|
|||
lua_koboldbridge = None # `koboldbridge` from bridge.lua
|
||||
lua_kobold = None # `kobold` from` bridge.lua
|
||||
lua_koboldcore = None # `koboldcore` from bridge.lua
|
||||
lua_warper = None # Transformers logits warper controllable from Lua
|
||||
# badwords = [] # Array of str/chr values that should be removed from output
|
||||
badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting
|
||||
deletewi = -1 # Temporary storage for index to delete
|
||||
|
@ -643,6 +644,37 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
class LuaLogitsWarper(LogitsWarper):
|
||||
|
||||
def __init__(self):
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
pass
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
assert scores.ndim == 2
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
scores_shape = scores.shape
|
||||
scores_list = self.scores.tolist()
|
||||
vars.lua_koboldbridge.logits = vars.lua_state.table()
|
||||
for r, row in enumerate(scores_list):
|
||||
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
|
||||
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
|
||||
execute_genmod()
|
||||
scores = torch.tensor(
|
||||
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
|
||||
device=scores.device,
|
||||
dtype=scores.dtype,
|
||||
)
|
||||
assert scores.shape == scores_shape
|
||||
if(vars.lua_koboldbridge.regeneration_required):
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
self.regeneration_required = True
|
||||
if(not vars.lua_koboldbridge.generating):
|
||||
self.halt = True
|
||||
return scores
|
||||
|
||||
def new_get_logits_warper(
|
||||
top_k: int = None,
|
||||
|
@ -660,6 +692,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1)))
|
||||
if(temp is not None and temp != 1.0):
|
||||
warper_list.append(TemperatureLogitsWarper(temperature=temp))
|
||||
vars.lua_warper = LuaLogitsWarper()
|
||||
warper_list.append(vars.lua_warper)
|
||||
return warper_list
|
||||
|
||||
def new_sample(self, *args, **kwargs):
|
||||
|
@ -705,15 +739,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
) -> bool:
|
||||
assert input_ids.ndim == 2
|
||||
assert len(self.excluded_world_info) == input_ids.shape[0]
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
|
||||
execute_genmod()
|
||||
if(vars.lua_koboldbridge.regeneration_required):
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
self.regeneration_required = True
|
||||
if(not vars.lua_koboldbridge.generating):
|
||||
self.halt = True
|
||||
self.regeneration_required = vars.lua_warper.regeneration_required
|
||||
self.halt = vars.lua_warper.halt
|
||||
|
||||
if(not vars.dynamicscan):
|
||||
return False
|
||||
|
|
76
bridge.lua
76
bridge.lua
|
@ -145,6 +145,7 @@ return function(_python, _bridged)
|
|||
---@field modelbackend "'readonly'"|"'api'"|"'transformers'"|"'mtj'"
|
||||
---@field is_custommodel boolean
|
||||
---@field custmodpth string
|
||||
---@field logits table<integer, table<integer, number>>
|
||||
local kobold = setmetatable({}, metawrapper)
|
||||
local KoboldLib_mt = setmetatable({}, metawrapper)
|
||||
local KoboldLib_getters = setmetatable({}, metawrapper)
|
||||
|
@ -200,6 +201,8 @@ return function(_python, _bridged)
|
|||
koboldbridge.resend_settings_required = false
|
||||
koboldbridge.generating = true
|
||||
koboldbridge.userstate = "inmod"
|
||||
koboldbridge.logits = {}
|
||||
koboldbridge.vocab_size = 0
|
||||
|
||||
---@return nil
|
||||
local function maybe_require_regeneration()
|
||||
|
@ -883,13 +886,13 @@ return function(_python, _bridged)
|
|||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return string
|
||||
---@return boolean
|
||||
function KoboldLib_getters.is_custommodel(t)
|
||||
return bridged.is_custommodel()
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@param v string
|
||||
---@param v boolean
|
||||
function KoboldLib_setters.is_custommodel(t, v)
|
||||
error("`KoboldLib.is_custommodel` is a read-only attribute")
|
||||
end
|
||||
|
@ -907,6 +910,65 @@ return function(_python, _bridged)
|
|||
end
|
||||
|
||||
|
||||
--==========================================================================
|
||||
-- Userscript API: Logit Warping
|
||||
--==========================================================================
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return integer
|
||||
function KoboldLib_getters.logits_rows(t)
|
||||
local backend = kobold.modelbackend
|
||||
if backend == "readonly" or backend == "api" then
|
||||
return 0
|
||||
end
|
||||
return kobold.settings.numseqs
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return integer
|
||||
function KoboldLib_setters.logits_rows(t)
|
||||
error("`KoboldLib.logits_rows` is a read-only attribute")
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return integer
|
||||
function KoboldLib_getters.logits_cols(t)
|
||||
local backend = kobold.modelbackend
|
||||
if backend == "readonly" or backend == "api" then
|
||||
return 0
|
||||
end
|
||||
return math.tointeger(koboldbridge.vocab_size)
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return integer
|
||||
function KoboldLib_setters.logits_cols(t)
|
||||
error("`KoboldLib.logits_cols` is a read-only attribute")
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return table<integer, table<integer, number>>
|
||||
function KoboldLib_getters.logits(t)
|
||||
if koboldbridge.userstate ~= "genmod" then
|
||||
return
|
||||
end
|
||||
return koboldbridge.logits
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@param v table<integer, table<integer, number>>
|
||||
function KoboldLib_setters.logits(t, v)
|
||||
if koboldbridge.userstate ~= "genmod" then
|
||||
error("Cannot write to `KoboldLib.logits` from outside of a generation modifer")
|
||||
return
|
||||
elseif type(v) ~= "table" then
|
||||
error("`KoboldLib.logits` must be a 2D list (table) of numbers; you attempted to set it to a " .. type(v))
|
||||
return
|
||||
end
|
||||
koboldbridge.logits = v
|
||||
end
|
||||
|
||||
|
||||
--==========================================================================
|
||||
-- Userscript API: Utilities
|
||||
--==========================================================================
|
||||
|
@ -1386,6 +1448,16 @@ return function(_python, _bridged)
|
|||
koboldbridge.userstate = "genmod"
|
||||
if koboldbridge.genmod ~= nil then
|
||||
r = koboldbridge.genmod()
|
||||
setmetatable(koboldbridge.logits, nil)
|
||||
for kr, vr in old_next, koboldbridge.logits, nil do
|
||||
setmetatable(vr, nil)
|
||||
for kc, vc in old_next, vr, nil do
|
||||
if type(vc) ~= "number" then
|
||||
error("`kobold.logits` must be a 2D table of numbers, but found a non-number element at row " .. kr .. ", column " .. kc)
|
||||
return r
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
return r
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue