mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-03-01 18:07:50 +01:00
Enable generation modifiers for transformers backend only
This commit is contained in:
parent
1111408cc2
commit
f8aa578f41
16
aiserver.py
16
aiserver.py
@ -690,7 +690,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||||||
excluded_world_info: List[Set],
|
excluded_world_info: List[Set],
|
||||||
head_length: int,
|
head_length: int,
|
||||||
):
|
):
|
||||||
self.any_new_entries = False
|
self.regeneration_required = False
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.excluded_world_info = excluded_world_info
|
self.excluded_world_info = excluded_world_info
|
||||||
self.head_length = head_length
|
self.head_length = head_length
|
||||||
@ -702,7 +702,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
assert len(self.excluded_world_info) == input_ids.shape[0]
|
assert len(self.excluded_world_info) == input_ids.shape[0]
|
||||||
self.any_new_entries = False
|
self.regeneration_required = False
|
||||||
|
|
||||||
|
vars.lua_koboldbridge.genmod()
|
||||||
|
if(vars.lua_koboldbridge.regeneration_required):
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
self.regeneration_required = True
|
||||||
|
|
||||||
if(not vars.dynamicscan):
|
if(not vars.dynamicscan):
|
||||||
return False
|
return False
|
||||||
tail = input_ids[..., self.head_length:]
|
tail = input_ids[..., self.head_length:]
|
||||||
@ -711,9 +717,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||||||
_, found = checkworldinfo(decoded, force_use_txt=True)
|
_, found = checkworldinfo(decoded, force_use_txt=True)
|
||||||
found -= self.excluded_world_info[i]
|
found -= self.excluded_world_info[i]
|
||||||
if(len(found) != 0):
|
if(len(found) != 0):
|
||||||
self.any_new_entries = True
|
self.regeneration_required = True
|
||||||
break
|
break
|
||||||
return self.any_new_entries
|
return self.regeneration_required
|
||||||
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
|
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
|
||||||
def new_get_stopping_criteria(self, *args, **kwargs):
|
def new_get_stopping_criteria(self, *args, **kwargs):
|
||||||
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
|
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
|
||||||
@ -1910,7 +1916,7 @@ def generate(txt, minimum, maximum, found_entries=None):
|
|||||||
num_return_sequences=numseqs
|
num_return_sequences=numseqs
|
||||||
)
|
)
|
||||||
already_generated += len(genout[0]) - len(gen_in[0])
|
already_generated += len(genout[0]) - len(gen_in[0])
|
||||||
if(not model.kai_scanner.any_new_entries):
|
if(not model.kai_scanner.regeneration_required):
|
||||||
break
|
break
|
||||||
assert genout.ndim >= 2
|
assert genout.ndim >= 2
|
||||||
assert genout.shape[0] == vars.numseqs
|
assert genout.shape[0] == vars.numseqs
|
||||||
|
30
bridge.lua
30
bridge.lua
@ -180,15 +180,23 @@ return function(_python, _bridged)
|
|||||||
---@class KoboldBridgeLib
|
---@class KoboldBridgeLib
|
||||||
local koboldbridge = setmetatable({}, metawrapper)
|
local koboldbridge = setmetatable({}, metawrapper)
|
||||||
|
|
||||||
|
koboldbridge.genmod_comparison_context = nil
|
||||||
|
koboldbridge.regeneration_required = false
|
||||||
|
koboldbridge.generating = true
|
||||||
|
koboldbridge.userstate = "inmod"
|
||||||
|
|
||||||
|
---@return nil
|
||||||
|
local function maybe_require_regeneration()
|
||||||
|
if koboldbridge.userstate == "genmod" and koboldbridge.genmod_comparison_context == nil then
|
||||||
|
koboldbridge.regeneration_required = true
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
--==========================================================================
|
--==========================================================================
|
||||||
-- Userscript API: World Info
|
-- Userscript API: World Info
|
||||||
--==========================================================================
|
--==========================================================================
|
||||||
|
|
||||||
local genmod_comparison_context = nil
|
|
||||||
koboldbridge.generating = true
|
|
||||||
koboldbridge.userstate = "inmod"
|
|
||||||
|
|
||||||
local fields = setmetatable({}, metawrapper)
|
local fields = setmetatable({}, metawrapper)
|
||||||
|
|
||||||
---@param t KoboldWorldInfoEntry|KoboldWorldInfoFolder|KoboldWorldInfo|KoboldWorldInfoFolderSelector
|
---@param t KoboldWorldInfoEntry|KoboldWorldInfoFolder|KoboldWorldInfo|KoboldWorldInfoFolderSelector
|
||||||
@ -201,13 +209,6 @@ return function(_python, _bridged)
|
|||||||
return true
|
return true
|
||||||
end
|
end
|
||||||
|
|
||||||
---@return nil
|
|
||||||
local function maybe_save_genmod_comparison_context()
|
|
||||||
if koboldbridge.userstate == "genmod" and genmod_comparison_context == nil then
|
|
||||||
genmod_comparison_context = kobold.worldinfo:compute_context()
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
----------------------------------------------------------------------------
|
----------------------------------------------------------------------------
|
||||||
|
|
||||||
@ -306,7 +307,7 @@ return function(_python, _bridged)
|
|||||||
return
|
return
|
||||||
else
|
else
|
||||||
if k ~= "comment" then
|
if k ~= "comment" then
|
||||||
maybe_save_genmod_comparison_context()
|
maybe_require_regeneration()
|
||||||
end
|
end
|
||||||
bridged.set_attr(t.uid, k, v)
|
bridged.set_attr(t.uid, k, v)
|
||||||
return t
|
return t
|
||||||
@ -589,6 +590,7 @@ return function(_python, _bridged)
|
|||||||
---@param t KoboldSettings_base
|
---@param t KoboldSettings_base
|
||||||
function KoboldSettings_mt.__newindex(t, k, v)
|
function KoboldSettings_mt.__newindex(t, k, v)
|
||||||
if k == "gen_len" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then
|
if k == "gen_len" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then
|
||||||
|
maybe_require_regeneration()
|
||||||
bridged.set_gen_len(v)
|
bridged.set_gen_len(v)
|
||||||
elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then
|
elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then
|
||||||
if koboldbridge.userstate == "genmod" then
|
if koboldbridge.userstate == "genmod" then
|
||||||
@ -598,7 +600,7 @@ return function(_python, _bridged)
|
|||||||
bridged.set_numseqs(v)
|
bridged.set_numseqs(v)
|
||||||
elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then
|
elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then
|
||||||
if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then
|
if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then
|
||||||
maybe_save_genmod_comparison_context()
|
maybe_require_regeneration()
|
||||||
end
|
end
|
||||||
return bridged.set_setting(k, v)
|
return bridged.set_setting(k, v)
|
||||||
end
|
end
|
||||||
@ -626,7 +628,7 @@ return function(_python, _bridged)
|
|||||||
error("`KoboldLib.memory` must be a string; you attempted to set it to a "..type(v))
|
error("`KoboldLib.memory` must be a string; you attempted to set it to a "..type(v))
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
maybe_save_genmod_comparison_context()
|
maybe_require_regeneration()
|
||||||
bridged.set_memory(v)
|
bridged.set_memory(v)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user