Update stoppers.py

This commit is contained in:
YellowRoseCx
2023-06-29 02:34:08 -05:00
committed by GitHub
parent ff31a0bc86
commit 91d543bf5a

View File

@@ -116,28 +116,6 @@ class Stoppers:
return True
return False
@staticmethod
def adventure_mode_stopper(
model: InferenceModel,
input_ids: torch.LongTensor,
) -> bool:
if not utils.koboldai_vars.adventure:
return False
data = [model.tokenizer.decode(x) for x in input_ids]
# null_character = model.tokenizer.encode(chr(0))[0]
if "completed" not in model.gen_state:
model.gen_state["completed"] = [False] * len(input_ids)
for i in range(len(input_ids)):
if (data[i][-6:] == "> You " or data[i][-4:] == "You:"):
model.gen_state["completed"][i] = True
if all(model.gen_state["completed"]):
utils.koboldai_vars.generated_tkns = utils.koboldai_vars.genamt
del model.gen_state["completed"]
return True
return False
@staticmethod
def stop_sequence_stopper(
@@ -149,7 +127,12 @@ class Stoppers:
# null_character = model.tokenizer.encode(chr(0))[0]
if "completed" not in model.gen_state:
model.gen_state["completed"] = [False] * len(input_ids)
if utils.koboldai_vars.adventure:
extra_options = ["> You", "You:", "\n\n You", "\n\nYou", ". You"]
for option in extra_options:
if option not in utils.koboldai_vars.stop_sequence:
utils.koboldai_vars.stop_sequence.append(option)
#one issue is that the stop sequence may not actual align with the end of token
#if its a subsection of a longer token
for stopper in utils.koboldai_vars.stop_sequence:
@@ -163,6 +146,10 @@ class Stoppers:
if all(model.gen_state["completed"]):
utils.koboldai_vars.generated_tkns = utils.koboldai_vars.genamt
del model.gen_state["completed"]
if utils.koboldai_vars.adventure: # Remove added adventure mode stop sequences
for option in extra_options:
if option in utils.koboldai_vars.stop_sequence:
utils.koboldai_vars.stop_sequence.remove(option)
return True
return False