mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Complete the Lua generation halting API
This commit is contained in:
10
aiserver.py
10
aiserver.py
@ -693,6 +693,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
head_length: int,
|
||||
):
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
self.tokenizer = tokenizer
|
||||
self.excluded_world_info = excluded_world_info
|
||||
self.head_length = head_length
|
||||
@ -705,11 +706,14 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
assert input_ids.ndim == 2
|
||||
assert len(self.excluded_world_info) == input_ids.shape[0]
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
|
||||
vars.lua_koboldbridge.genmod()
|
||||
execute_genmod()
|
||||
if(vars.lua_koboldbridge.regeneration_required):
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
self.regeneration_required = True
|
||||
if(not vars.lua_koboldbridge.generating):
|
||||
self.halt = True
|
||||
|
||||
if(not vars.dynamicscan):
|
||||
return False
|
||||
@ -721,7 +725,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
if(len(found) != 0):
|
||||
self.regeneration_required = True
|
||||
break
|
||||
return self.regeneration_required
|
||||
return self.regeneration_required or self.halt
|
||||
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
|
||||
def new_get_stopping_criteria(self, *args, **kwargs):
|
||||
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
|
||||
@ -2041,7 +2045,7 @@ def generate(txt, minimum, maximum, found_entries=None):
|
||||
num_return_sequences=numseqs
|
||||
)
|
||||
already_generated += len(genout[0]) - len(gen_in[0])
|
||||
if(not model.kai_scanner.regeneration_required):
|
||||
if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
|
||||
break
|
||||
assert genout.ndim >= 2
|
||||
assert genout.shape[0] == vars.numseqs
|
||||
|
Reference in New Issue
Block a user