Model: Add singleline_stopper and fix stopper code

singleline_stopper adapted from MasterAibo in 0ba7ac9
Author: somebody
Date: 2023-03-04 18:30:40 -06:00
parent 6857d2f7e1
commit b02513df07
2 changed files with 35 additions and 2 deletions


@@ -63,9 +63,13 @@ class HFTorchInferenceModel(HFInferenceModel):
        self.low_mem = low_mem

        self.post_token_hooks = [
            Stoppers.core_stopper,
            PostTokenHooks.stream_tokens,
        ]

        self.stopper_hooks = [
            Stoppers.core_stopper,
            Stoppers.dynamic_wi_scanner,
            Stoppers.singleline_stopper,
            Stoppers.chat_mode_stopper,
        ]
@@ -104,7 +108,7 @@ class HFTorchInferenceModel(HFInferenceModel):
        self._post_token_gen(input_ids)

        for stopper in self.stopper_hooks:
            do_stop = stopper(input_ids)
            do_stop = stopper(self, input_ids)

            if do_stop:
                return True

        return False
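
For reference, a rough sketch of what a hook conforming to the fixed calling convention looks like: every entry in stopper_hooks is now called as stopper(self, input_ids), i.e. it receives the model instance plus the current batch of token ids and returns True to request that generation stop. The hook below and its token budget are hypothetical illustrations, not part of this commit:

import torch

def example_token_budget_stopper(model, input_ids: torch.LongTensor) -> bool:
    # Hypothetical hook: ask generation to stop once the sequences grow past
    # 2048 tokens (prompt plus generated text). Only the (model, input_ids)
    # signature and the boolean return come from the diff above.
    return input_ids.shape[-1] >= 2048

# Registration would mirror the stopper_hooks list above, e.g.:
# self.stopper_hooks.append(example_token_budget_stopper)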


@@ -115,3 +115,32 @@ class Stoppers:
                del model.gen_state["completed"]
                return True
        return False

    @staticmethod
    def singleline_stopper(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""

        if not utils.koboldai_vars.singleline:
            return False

        # Keep track of presence of newlines in each sequence; we cannot stop a
        # batch member individually, so we must wait for all of them to contain
        # a newline.
        if "newline_in_sequence" not in model.gen_state:
            model.gen_state["newline_in_sequence"] = [False] * len(input_ids)
        print(model.gen_state["newline_in_sequence"])

        for sequence_idx, batch_sequence in enumerate(input_ids):
            if model.tokenizer.decode(batch_sequence[-1]) == "\n":
                model.gen_state["newline_in_sequence"][sequence_idx] = True

        if all(model.gen_state["newline_in_sequence"]):
            del model.gen_state["newline_in_sequence"]
            print("OUT")
            return True
        print("nah its ok")
        return False
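
For readers unfamiliar with the stopper's batch handling, here is a rough usage sketch (not part of the commit) that exercises singleline_stopper with a stubbed model to show that it only requests a stop once every sequence in the batch has emitted a newline. The import path modeling.stoppers, and the assumption that utils.koboldai_vars.singleline can simply be toggled in this context, are guesses for illustration:

import torch
import utils
from modeling.stoppers import Stoppers  # assumed location of the Stoppers class

class StubTokenizer:
    def decode(self, token_id) -> str:
        # Pretend token id 0 is the newline token; everything else is ordinary text.
        return "\n" if int(token_id) == 0 else "x"

class StubModel:
    def __init__(self):
        self.gen_state = {}  # persists across generation steps
        self.tokenizer = StubTokenizer()

utils.koboldai_vars.singleline = True  # assumed settable outside the main app
model = StubModel()

# Step 1: only the first of two sequences ends in a newline, so no stop yet.
assert Stoppers.singleline_stopper(model, torch.tensor([[5, 6, 0], [5, 6, 7]])) is False

# Step 2: the second sequence has now produced a newline as well, so the
# stopper fires and clears its per-generation state.
assert Stoppers.singleline_stopper(model, torch.tensor([[5, 6, 0, 0], [5, 6, 7, 0]])) is True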