Complete the Lua generation halting API

This commit is contained in:
Gnome Ann 2021-12-12 12:52:03 -05:00
parent e06861bb0b
commit e2c3ac041b
2 changed files with 12 additions and 13 deletions

View File

@ -693,6 +693,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
head_length: int,
):
self.regeneration_required = False
self.halt = False
self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info
self.head_length = head_length
@ -705,11 +706,14 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
assert input_ids.ndim == 2
assert len(self.excluded_world_info) == input_ids.shape[0]
self.regeneration_required = False
self.halt = False
vars.lua_koboldbridge.genmod()
execute_genmod()
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
self.regeneration_required = True
if(not vars.lua_koboldbridge.generating):
self.halt = True
if(not vars.dynamicscan):
return False
@ -721,7 +725,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
if(len(found) != 0):
self.regeneration_required = True
break
return self.regeneration_required
return self.regeneration_required or self.halt
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@ -2041,7 +2045,7 @@ def generate(txt, minimum, maximum, found_entries=None):
num_return_sequences=numseqs
)
already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.regeneration_required):
if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
break
assert genout.ndim >= 2
assert genout.shape[0] == vars.numseqs

View File

@ -196,7 +196,6 @@ return function(_python, _bridged)
---@class KoboldBridgeLib
local koboldbridge = setmetatable({}, metawrapper)
koboldbridge.genmod_comparison_context = nil
koboldbridge.regeneration_required = false
koboldbridge.resend_settings_required = false
koboldbridge.generating = true
@ -204,7 +203,7 @@ return function(_python, _bridged)
---@return nil
local function maybe_require_regeneration()
if koboldbridge.userstate == "genmod" and koboldbridge.genmod_comparison_context == nil then
if koboldbridge.userstate == "genmod" then
koboldbridge.regeneration_required = true
end
end
@ -1372,7 +1371,7 @@ return function(_python, _bridged)
function koboldbridge.execute_inmod()
local r
koboldbridge.generating = true
koboldbridge.generating = false
koboldbridge.userstate = "inmod"
if koboldbridge.inmod ~= nil then
r = koboldbridge.inmod()
@ -1383,20 +1382,17 @@ return function(_python, _bridged)
---@return any, boolean
function koboldbridge.execute_genmod()
local r
local changed = false
koboldbridge.generating = true
koboldbridge.userstate = "genmod"
if koboldbridge.genmod ~= nil then
r = koboldbridge.genmod()
if genmod_comparison_context ~= kobold.worldinfo:compute_context() then
changed = true
genmod_comparison_context = nil
end
end
return r, changed
return r
end
function koboldbridge.execute_outmod()
local r
koboldbridge.generating = false
koboldbridge.userstate = "outmod"
if koboldbridge.outmod ~= nil then
r = koboldbridge.outmod()
@ -1404,7 +1400,6 @@ return function(_python, _bridged)
if koboldbridge.resend_settings_required then
bridged.resend_settings()
end
koboldbridge.generating = true
koboldbridge.userstate = "inmod"
return r
end