mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Enable generation modifiers for transformers backend only
This commit is contained in:
16
aiserver.py
16
aiserver.py
@ -690,7 +690,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
excluded_world_info: List[Set],
|
||||
head_length: int,
|
||||
):
|
||||
self.any_new_entries = False
|
||||
self.regeneration_required = False
|
||||
self.tokenizer = tokenizer
|
||||
self.excluded_world_info = excluded_world_info
|
||||
self.head_length = head_length
|
||||
@ -702,7 +702,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
) -> bool:
|
||||
assert input_ids.ndim == 2
|
||||
assert len(self.excluded_world_info) == input_ids.shape[0]
|
||||
self.any_new_entries = False
|
||||
self.regeneration_required = False
|
||||
|
||||
vars.lua_koboldbridge.genmod()
|
||||
if(vars.lua_koboldbridge.regeneration_required):
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
self.regeneration_required = True
|
||||
|
||||
if(not vars.dynamicscan):
|
||||
return False
|
||||
tail = input_ids[..., self.head_length:]
|
||||
@ -711,9 +717,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
_, found = checkworldinfo(decoded, force_use_txt=True)
|
||||
found -= self.excluded_world_info[i]
|
||||
if(len(found) != 0):
|
||||
self.any_new_entries = True
|
||||
self.regeneration_required = True
|
||||
break
|
||||
return self.any_new_entries
|
||||
return self.regeneration_required
|
||||
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
|
||||
def new_get_stopping_criteria(self, *args, **kwargs):
|
||||
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
|
||||
@ -1910,7 +1916,7 @@ def generate(txt, minimum, maximum, found_entries=None):
|
||||
num_return_sequences=numseqs
|
||||
)
|
||||
already_generated += len(genout[0]) - len(gen_in[0])
|
||||
if(not model.kai_scanner.any_new_entries):
|
||||
if(not model.kai_scanner.regeneration_required):
|
||||
break
|
||||
assert genout.ndim >= 2
|
||||
assert genout.shape[0] == vars.numseqs
|
||||
|
Reference in New Issue
Block a user