Enable generation modifiers for transformers backend only

This commit is contained in:
Gnome Ann
2021-12-11 16:28:25 -05:00
parent 1111408cc2
commit f8aa578f41
2 changed files with 27 additions and 19 deletions

View File

@ -690,7 +690,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
excluded_world_info: List[Set],
head_length: int,
):
self.any_new_entries = False
self.regeneration_required = False
self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info
self.head_length = head_length
@ -702,7 +702,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
) -> bool:
assert input_ids.ndim == 2
assert len(self.excluded_world_info) == input_ids.shape[0]
self.any_new_entries = False
self.regeneration_required = False
vars.lua_koboldbridge.genmod()
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
self.regeneration_required = True
if(not vars.dynamicscan):
return False
tail = input_ids[..., self.head_length:]
@ -711,9 +717,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
_, found = checkworldinfo(decoded, force_use_txt=True)
found -= self.excluded_world_info[i]
if(len(found) != 0):
self.any_new_entries = True
self.regeneration_required = True
break
return self.any_new_entries
return self.regeneration_required
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@ -1910,7 +1916,7 @@ def generate(txt, minimum, maximum, found_entries=None):
num_return_sequences=numseqs
)
already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.any_new_entries):
if(not model.kai_scanner.regeneration_required):
break
assert genout.ndim >= 2
assert genout.shape[0] == vars.numseqs