From b02513df07bfbc74a0d96d4c714ccdf4b0c84932 Mon Sep 17 00:00:00 2001
From: somebody
Date: Sat, 4 Mar 2023 18:30:40 -0600
Subject: [PATCH] Model: Add singleline_stopper and fix stopper code

singleline_stopper adapted from MasterAibo in 0ba7ac9
---
 modeling/inference_models/hf_torch.py |  8 ++++++--
 modeling/stoppers.py                  | 29 +++++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 94f57272..fffe08cc 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -63,9 +63,13 @@ class HFTorchInferenceModel(HFInferenceModel):
         self.low_mem = low_mem
 
         self.post_token_hooks = [
-            Stoppers.core_stopper,
             PostTokenHooks.stream_tokens,
+        ]
+
+        self.stopper_hooks = [
+            Stoppers.core_stopper,
             Stoppers.dynamic_wi_scanner,
+            Stoppers.singleline_stopper,
             Stoppers.chat_mode_stopper,
         ]
 
@@ -104,7 +108,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         self._post_token_gen(input_ids)
 
         for stopper in self.stopper_hooks:
-            do_stop = stopper(input_ids)
+            do_stop = stopper(self, input_ids)
             if do_stop:
                 return True
         return False
diff --git a/modeling/stoppers.py b/modeling/stoppers.py
index 2cb8af49..b4eeaf51 100644
--- a/modeling/stoppers.py
+++ b/modeling/stoppers.py
@@ -115,3 +115,32 @@ class Stoppers:
             del model.gen_state["completed"]
             return True
         return False
+
+    @staticmethod
+    def singleline_stopper(
+        model: InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
+
+        if not utils.koboldai_vars.singleline:
+            return False
+
+        # Keep track of presence of newlines in each sequence; we cannot stop a
+        # batch member individually, so we must wait for all of them to contain
+        # a newline.
+        if "newline_in_sequence" not in model.gen_state:
+            model.gen_state["newline_in_sequence"] = [False] * len(input_ids)
+
+        print(model.gen_state["newline_in_sequence"])
+
+        for sequence_idx, batch_sequence in enumerate(input_ids):
+            if model.tokenizer.decode(batch_sequence[-1]) == "\n":
+                model.gen_state["newline_in_sequence"][sequence_idx] = True
+
+        if all(model.gen_state["newline_in_sequence"]):
+            del model.gen_state["newline_in_sequence"]
+            print("OUT")
+            return True
+        print("nah its ok")
+        return False
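
Note (not part of the patch): a minimal, self-contained sketch of the batch-wide newline
tracking that singleline_stopper performs. FakeTokenizer, FakeModel, and the choice of
token id 0 as the newline token are hypothetical stand-ins for model.tokenizer and
model.gen_state; the koboldai_vars.singleline check and the debug prints are omitted so
the example runs on its own. Only the stopping logic mirrors the patched code.

import torch


class FakeTokenizer:
    # Hypothetical stand-in for model.tokenizer: token id 0 decodes to a newline.
    def decode(self, token_id) -> str:
        return "\n" if int(token_id) == 0 else "x"


class FakeModel:
    # Hypothetical stand-in carrying the per-generation state dict and tokenizer.
    def __init__(self):
        self.gen_state = {}
        self.tokenizer = FakeTokenizer()


def singleline_should_stop(model, input_ids: torch.LongTensor) -> bool:
    # Track which batch members have produced a newline so far; the stopper can
    # only halt the whole batch, so it waits until every sequence has one.
    if "newline_in_sequence" not in model.gen_state:
        model.gen_state["newline_in_sequence"] = [False] * len(input_ids)

    for sequence_idx, batch_sequence in enumerate(input_ids):
        if model.tokenizer.decode(batch_sequence[-1]) == "\n":
            model.gen_state["newline_in_sequence"][sequence_idx] = True

    if all(model.gen_state["newline_in_sequence"]):
        del model.gen_state["newline_in_sequence"]
        return True
    return False


model = FakeModel()
# Batch of two sequences; only the first ends in a newline (token id 0), so no stop yet.
assert singleline_should_stop(model, torch.tensor([[5, 7, 0], [5, 7, 9]])) is False
# One step later the second sequence also emits a newline, so generation stops.
assert singleline_should_stop(model, torch.tensor([[5, 7, 0, 0], [5, 7, 9, 0]])) is True

Because the flag list persists in gen_state between calls, a sequence that hit a newline
on an earlier step keeps counting as finished, which matches the patch's comment that a
single batch member cannot be stopped individually.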