Enable generation modifiers for transformers backend only

This commit is contained in:
Gnome Ann 2021-12-11 16:28:25 -05:00
parent 1111408cc2
commit f8aa578f41
2 changed files with 27 additions and 19 deletions

View File

@ -690,7 +690,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
excluded_world_info: List[Set], excluded_world_info: List[Set],
head_length: int, head_length: int,
): ):
self.any_new_entries = False self.regeneration_required = False
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info self.excluded_world_info = excluded_world_info
self.head_length = head_length self.head_length = head_length
@ -702,7 +702,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
) -> bool: ) -> bool:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
assert len(self.excluded_world_info) == input_ids.shape[0] assert len(self.excluded_world_info) == input_ids.shape[0]
self.any_new_entries = False self.regeneration_required = False
vars.lua_koboldbridge.genmod()
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
self.regeneration_required = True
if(not vars.dynamicscan): if(not vars.dynamicscan):
return False return False
tail = input_ids[..., self.head_length:] tail = input_ids[..., self.head_length:]
@ -711,9 +717,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
_, found = checkworldinfo(decoded, force_use_txt=True) _, found = checkworldinfo(decoded, force_use_txt=True)
found -= self.excluded_world_info[i] found -= self.excluded_world_info[i]
if(len(found) != 0): if(len(found) != 0):
self.any_new_entries = True self.regeneration_required = True
break break
return self.any_new_entries return self.regeneration_required
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs): def new_get_stopping_criteria(self, *args, **kwargs):
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs) stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@ -1910,7 +1916,7 @@ def generate(txt, minimum, maximum, found_entries=None):
num_return_sequences=numseqs num_return_sequences=numseqs
) )
already_generated += len(genout[0]) - len(gen_in[0]) already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.any_new_entries): if(not model.kai_scanner.regeneration_required):
break break
assert genout.ndim >= 2 assert genout.ndim >= 2
assert genout.shape[0] == vars.numseqs assert genout.shape[0] == vars.numseqs

View File

@ -180,15 +180,23 @@ return function(_python, _bridged)
---@class KoboldBridgeLib ---@class KoboldBridgeLib
local koboldbridge = setmetatable({}, metawrapper) local koboldbridge = setmetatable({}, metawrapper)
koboldbridge.genmod_comparison_context = nil
koboldbridge.regeneration_required = false
koboldbridge.generating = true
koboldbridge.userstate = "inmod"
---@return nil
local function maybe_require_regeneration()
if koboldbridge.userstate == "genmod" and koboldbridge.genmod_comparison_context == nil then
koboldbridge.regeneration_required = true
end
end
--========================================================================== --==========================================================================
-- Userscript API: World Info -- Userscript API: World Info
--========================================================================== --==========================================================================
local genmod_comparison_context = nil
koboldbridge.generating = true
koboldbridge.userstate = "inmod"
local fields = setmetatable({}, metawrapper) local fields = setmetatable({}, metawrapper)
---@param t KoboldWorldInfoEntry|KoboldWorldInfoFolder|KoboldWorldInfo|KoboldWorldInfoFolderSelector ---@param t KoboldWorldInfoEntry|KoboldWorldInfoFolder|KoboldWorldInfo|KoboldWorldInfoFolderSelector
@ -201,13 +209,6 @@ return function(_python, _bridged)
return true return true
end end
---@return nil
local function maybe_save_genmod_comparison_context()
if koboldbridge.userstate == "genmod" and genmod_comparison_context == nil then
genmod_comparison_context = kobold.worldinfo:compute_context()
end
end
---------------------------------------------------------------------------- ----------------------------------------------------------------------------
@ -306,7 +307,7 @@ return function(_python, _bridged)
return return
else else
if k ~= "comment" then if k ~= "comment" then
maybe_save_genmod_comparison_context() maybe_require_regeneration()
end end
bridged.set_attr(t.uid, k, v) bridged.set_attr(t.uid, k, v)
return t return t
@ -589,6 +590,7 @@ return function(_python, _bridged)
---@param t KoboldSettings_base ---@param t KoboldSettings_base
function KoboldSettings_mt.__newindex(t, k, v) function KoboldSettings_mt.__newindex(t, k, v)
if k == "gen_len" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then if k == "gen_len" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then
maybe_require_regeneration()
bridged.set_gen_len(v) bridged.set_gen_len(v)
elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then
if koboldbridge.userstate == "genmod" then if koboldbridge.userstate == "genmod" then
@ -598,7 +600,7 @@ return function(_python, _bridged)
bridged.set_numseqs(v) bridged.set_numseqs(v)
elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then
if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then
maybe_save_genmod_comparison_context() maybe_require_regeneration()
end end
return bridged.set_setting(k, v) return bridged.set_setting(k, v)
end end
@ -626,7 +628,7 @@ return function(_python, _bridged)
error("`KoboldLib.memory` must be a string; you attempted to set it to a "..type(v)) error("`KoboldLib.memory` must be a string; you attempted to set it to a "..type(v))
return return
end end
maybe_save_genmod_comparison_context() maybe_require_regeneration()
bridged.set_memory(v) bridged.set_memory(v)
end end