Complete the Lua generation halting API

This commit is contained in:
Gnome Ann
2021-12-12 12:52:03 -05:00
parent e06861bb0b
commit e2c3ac041b
2 changed files with 12 additions and 13 deletions

View File

@ -693,6 +693,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
head_length: int,
):
self.regeneration_required = False
self.halt = False
self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info
self.head_length = head_length
@ -705,11 +706,14 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
assert input_ids.ndim == 2
assert len(self.excluded_world_info) == input_ids.shape[0]
self.regeneration_required = False
self.halt = False
vars.lua_koboldbridge.genmod()
execute_genmod()
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
self.regeneration_required = True
if(not vars.lua_koboldbridge.generating):
self.halt = True
if(not vars.dynamicscan):
return False
@ -721,7 +725,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
if(len(found) != 0):
self.regeneration_required = True
break
return self.regeneration_required
return self.regeneration_required or self.halt
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@ -2041,7 +2045,7 @@ def generate(txt, minimum, maximum, found_entries=None):
num_return_sequences=numseqs
)
already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.regeneration_required):
if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
break
assert genout.ndim >= 2
assert genout.shape[0] == vars.numseqs