Complete the Lua generation halting API

This commit is contained in:
Gnome Ann 2021-12-12 12:52:03 -05:00
parent e06861bb0b
commit e2c3ac041b
2 changed files with 12 additions and 13 deletions

View File

@ -693,6 +693,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
head_length: int, head_length: int,
): ):
self.regeneration_required = False self.regeneration_required = False
self.halt = False
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info self.excluded_world_info = excluded_world_info
self.head_length = head_length self.head_length = head_length
@ -705,11 +706,14 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
assert input_ids.ndim == 2 assert input_ids.ndim == 2
assert len(self.excluded_world_info) == input_ids.shape[0] assert len(self.excluded_world_info) == input_ids.shape[0]
self.regeneration_required = False self.regeneration_required = False
self.halt = False
vars.lua_koboldbridge.genmod() execute_genmod()
if(vars.lua_koboldbridge.regeneration_required): if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False vars.lua_koboldbridge.regeneration_required = False
self.regeneration_required = True self.regeneration_required = True
if(not vars.lua_koboldbridge.generating):
self.halt = True
if(not vars.dynamicscan): if(not vars.dynamicscan):
return False return False
@ -721,7 +725,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
if(len(found) != 0): if(len(found) != 0):
self.regeneration_required = True self.regeneration_required = True
break break
return self.regeneration_required return self.regeneration_required or self.halt
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs): def new_get_stopping_criteria(self, *args, **kwargs):
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs) stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@ -2041,7 +2045,7 @@ def generate(txt, minimum, maximum, found_entries=None):
num_return_sequences=numseqs num_return_sequences=numseqs
) )
already_generated += len(genout[0]) - len(gen_in[0]) already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.regeneration_required): if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
break break
assert genout.ndim >= 2 assert genout.ndim >= 2
assert genout.shape[0] == vars.numseqs assert genout.shape[0] == vars.numseqs

View File

@ -196,7 +196,6 @@ return function(_python, _bridged)
---@class KoboldBridgeLib ---@class KoboldBridgeLib
local koboldbridge = setmetatable({}, metawrapper) local koboldbridge = setmetatable({}, metawrapper)
koboldbridge.genmod_comparison_context = nil
koboldbridge.regeneration_required = false koboldbridge.regeneration_required = false
koboldbridge.resend_settings_required = false koboldbridge.resend_settings_required = false
koboldbridge.generating = true koboldbridge.generating = true
@ -204,7 +203,7 @@ return function(_python, _bridged)
---@return nil ---@return nil
local function maybe_require_regeneration() local function maybe_require_regeneration()
if koboldbridge.userstate == "genmod" and koboldbridge.genmod_comparison_context == nil then if koboldbridge.userstate == "genmod" then
koboldbridge.regeneration_required = true koboldbridge.regeneration_required = true
end end
end end
@ -1372,7 +1371,7 @@ return function(_python, _bridged)
function koboldbridge.execute_inmod() function koboldbridge.execute_inmod()
local r local r
koboldbridge.generating = true koboldbridge.generating = false
koboldbridge.userstate = "inmod" koboldbridge.userstate = "inmod"
if koboldbridge.inmod ~= nil then if koboldbridge.inmod ~= nil then
r = koboldbridge.inmod() r = koboldbridge.inmod()
@ -1383,20 +1382,17 @@ return function(_python, _bridged)
---@return any, boolean ---@return any, boolean
function koboldbridge.execute_genmod() function koboldbridge.execute_genmod()
local r local r
local changed = false koboldbridge.generating = true
koboldbridge.userstate = "genmod" koboldbridge.userstate = "genmod"
if koboldbridge.genmod ~= nil then if koboldbridge.genmod ~= nil then
r = koboldbridge.genmod() r = koboldbridge.genmod()
if genmod_comparison_context ~= kobold.worldinfo:compute_context() then
changed = true
genmod_comparison_context = nil
end
end end
return r, changed return r
end end
function koboldbridge.execute_outmod() function koboldbridge.execute_outmod()
local r local r
koboldbridge.generating = false
koboldbridge.userstate = "outmod" koboldbridge.userstate = "outmod"
if koboldbridge.outmod ~= nil then if koboldbridge.outmod ~= nil then
r = koboldbridge.outmod() r = koboldbridge.outmod()
@ -1404,7 +1400,6 @@ return function(_python, _bridged)
if koboldbridge.resend_settings_required then if koboldbridge.resend_settings_required then
bridged.resend_settings() bridged.resend_settings()
end end
koboldbridge.generating = true
koboldbridge.userstate = "inmod" koboldbridge.userstate = "inmod"
return r return r
end end