mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Update stoppers.py
This commit is contained in:
@@ -116,28 +116,6 @@ class Stoppers:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def adventure_mode_stopper(
|
||||
model: InferenceModel,
|
||||
input_ids: torch.LongTensor,
|
||||
) -> bool:
|
||||
if not utils.koboldai_vars.adventure:
|
||||
return False
|
||||
|
||||
data = [model.tokenizer.decode(x) for x in input_ids]
|
||||
# null_character = model.tokenizer.encode(chr(0))[0]
|
||||
if "completed" not in model.gen_state:
|
||||
model.gen_state["completed"] = [False] * len(input_ids)
|
||||
|
||||
for i in range(len(input_ids)):
|
||||
if (data[i][-6:] == "> You " or data[i][-4:] == "You:"):
|
||||
model.gen_state["completed"][i] = True
|
||||
|
||||
if all(model.gen_state["completed"]):
|
||||
utils.koboldai_vars.generated_tkns = utils.koboldai_vars.genamt
|
||||
del model.gen_state["completed"]
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def stop_sequence_stopper(
|
||||
@@ -149,7 +127,12 @@ class Stoppers:
|
||||
# null_character = model.tokenizer.encode(chr(0))[0]
|
||||
if "completed" not in model.gen_state:
|
||||
model.gen_state["completed"] = [False] * len(input_ids)
|
||||
|
||||
if utils.koboldai_vars.adventure:
|
||||
extra_options = ["> You", "You:", "\n\n You", "\n\nYou", ". You"]
|
||||
for option in extra_options:
|
||||
if option not in utils.koboldai_vars.stop_sequence:
|
||||
utils.koboldai_vars.stop_sequence.append(option)
|
||||
|
||||
#one issue is that the stop sequence may not actual align with the end of token
|
||||
#if its a subsection of a longer token
|
||||
for stopper in utils.koboldai_vars.stop_sequence:
|
||||
@@ -163,6 +146,10 @@ class Stoppers:
|
||||
if all(model.gen_state["completed"]):
|
||||
utils.koboldai_vars.generated_tkns = utils.koboldai_vars.genamt
|
||||
del model.gen_state["completed"]
|
||||
if utils.koboldai_vars.adventure: # Remove added adventure mode stop sequences
|
||||
for option in extra_options:
|
||||
if option in utils.koboldai_vars.stop_sequence:
|
||||
utils.koboldai_vars.stop_sequence.remove(option)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
Reference in New Issue
Block a user