From f8aa578f4140b08a39f61a2f14021f16d5e4e6d2 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sat, 11 Dec 2021 16:28:25 -0500 Subject: [PATCH] Enable generation modifiers for transformers backend only --- aiserver.py | 16 +++++++++++----- bridge.lua | 30 ++++++++++++++++-------------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/aiserver.py b/aiserver.py index 6249bb4e..fc735bc0 100644 --- a/aiserver.py +++ b/aiserver.py @@ -690,7 +690,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme excluded_world_info: List[Set], head_length: int, ): - self.any_new_entries = False + self.regeneration_required = False self.tokenizer = tokenizer self.excluded_world_info = excluded_world_info self.head_length = head_length @@ -702,7 +702,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme ) -> bool: assert input_ids.ndim == 2 assert len(self.excluded_world_info) == input_ids.shape[0] - self.any_new_entries = False + self.regeneration_required = False + + vars.lua_koboldbridge.genmod() + if(vars.lua_koboldbridge.regeneration_required): + vars.lua_koboldbridge.regeneration_required = False + self.regeneration_required = True + if(not vars.dynamicscan): return False tail = input_ids[..., self.head_length:] @@ -711,9 +717,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme _, found = checkworldinfo(decoded, force_use_txt=True) found -= self.excluded_world_info[i] if(len(found) != 0): - self.any_new_entries = True + self.regeneration_required = True break - return self.any_new_entries + return self.regeneration_required old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria def new_get_stopping_criteria(self, *args, **kwargs): stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs) @@ -1910,7 +1916,7 @@ def generate(txt, minimum, maximum, found_entries=None): num_return_sequences=numseqs ) already_generated += len(genout[0]) - len(gen_in[0]) - if(not model.kai_scanner.any_new_entries): + if(not model.kai_scanner.regeneration_required): break assert genout.ndim >= 2 assert genout.shape[0] == vars.numseqs diff --git a/bridge.lua b/bridge.lua index 780e5788..25c4febb 100644 --- a/bridge.lua +++ b/bridge.lua @@ -180,15 +180,23 @@ return function(_python, _bridged) ---@class KoboldBridgeLib local koboldbridge = setmetatable({}, metawrapper) + koboldbridge.genmod_comparison_context = nil + koboldbridge.regeneration_required = false + koboldbridge.generating = true + koboldbridge.userstate = "inmod" + + ---@return nil + local function maybe_require_regeneration() + if koboldbridge.userstate == "genmod" and koboldbridge.genmod_comparison_context == nil then + koboldbridge.regeneration_required = true + end + end + --========================================================================== -- Userscript API: World Info --========================================================================== - local genmod_comparison_context = nil - koboldbridge.generating = true - koboldbridge.userstate = "inmod" - local fields = setmetatable({}, metawrapper) ---@param t KoboldWorldInfoEntry|KoboldWorldInfoFolder|KoboldWorldInfo|KoboldWorldInfoFolderSelector @@ -201,13 +209,6 @@ return function(_python, _bridged) return true end - ---@return nil - local function maybe_save_genmod_comparison_context() - if koboldbridge.userstate == "genmod" and genmod_comparison_context == nil then - genmod_comparison_context = kobold.worldinfo:compute_context() - end - end - ---------------------------------------------------------------------------- @@ -306,7 +307,7 @@ return function(_python, _bridged) return else if k ~= "comment" then - maybe_save_genmod_comparison_context() + maybe_require_regeneration() end bridged.set_attr(t.uid, k, v) return t @@ -589,6 +590,7 @@ return function(_python, _bridged) ---@param t KoboldSettings_base function KoboldSettings_mt.__newindex(t, k, v) if k == "gen_len" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then + maybe_require_regeneration() bridged.set_gen_len(v) elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then if koboldbridge.userstate == "genmod" then @@ -598,7 +600,7 @@ return function(_python, _bridged) bridged.set_numseqs(v) elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then - maybe_save_genmod_comparison_context() + maybe_require_regeneration() end return bridged.set_setting(k, v) end @@ -626,7 +628,7 @@ return function(_python, _bridged) error("`KoboldLib.memory` must be a string; you attempted to set it to a "..type(v)) return end - maybe_save_genmod_comparison_context() + maybe_require_regeneration() bridged.set_memory(v) end