From e5d0a597a1806815ca7463a6536d6719ceb8d165 Mon Sep 17 00:00:00 2001
From: somebody
Date: Fri, 21 Jul 2023 15:36:32 -0500
Subject: [PATCH] Generation Mode: UNTIL_EOS

This mode unbans the EOS token and keeps generating until the model
emits it; the generation-amount limit is disabled and the same
1e7-token cap used by FOREVER applies.
---
 modeling/inference_model.py           | 24 +++++++++++++-----------
 modeling/inference_models/hf_torch.py | 13 ++++++++++++-
 2 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/modeling/inference_model.py b/modeling/inference_model.py
index e09249c3..8b7f0e3e 100644
--- a/modeling/inference_model.py
+++ b/modeling/inference_model.py
@@ -585,19 +585,21 @@ class InferenceModel:
             "wi_scanner_excluded_keys", set()
         )
 
+        self.gen_state["allow_eos"] = False
+
         temp_stoppers = []
 
+        if gen_mode not in self.get_supported_gen_modes():
+            logger.warning(f"User requested unsupported GenerationMode '{gen_mode}'!")
+            gen_mode = GenerationMode.STANDARD
+
         if gen_mode == GenerationMode.FOREVER:
-            if self.capabilties.stopper_hooks:
-                self.gen_state["stop_at_genamt"] = False
-                max_new = 1e7
-            else:
-                logger.warning(
-                    "User requested infinite generation on model that doesn't support stop hooks. Recipe for disaster!"
-                )
+            self.gen_state["stop_at_genamt"] = False
+            max_new = 1e7
         elif gen_mode == GenerationMode.UNTIL_EOS:
-            # Still need to unban
-            raise NotImplementedError()
+            self.gen_state["allow_eos"] = True
+            self.gen_state["stop_at_genamt"] = False
+            max_new = 1e7
         elif gen_mode == GenerationMode.UNTIL_NEWLINE:
             # TODO: Look into replacing `single_line` with `generation_mode`
             temp_stoppers.append(Stoppers.newline_stopper)
@@ -668,11 +670,11 @@ class InferenceModel:
         Returns:
             List[GenerationMode]: A list of compatible `GenerationMode`s.
         """
-        ret = []
+        ret = [GenerationMode.STANDARD]
+
         if self.capabilties.stopper_hooks:
             ret += [
                 GenerationMode.FOREVER,
-                GenerationMode.UNTIL_EOS,
                 GenerationMode.UNTIL_NEWLINE,
                 GenerationMode.UNTIL_SENTENCE_END,
             ]
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 1b411c95..b4909f60 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -31,6 +31,7 @@ from modeling.stoppers import Stoppers
 from modeling.post_token_hooks import PostTokenHooks
 from modeling.inference_models.hf import HFInferenceModel
 from modeling.inference_model import (
+    GenerationMode,
     GenerationResult,
     GenerationSettings,
     ModelCapabilities,
@@ -254,7 +255,11 @@ class HFTorchInferenceModel(HFInferenceModel):
                 kwargs["logits_warper"] = new_get_logits_warper(
                     beams=1,
                 )
-                if utils.koboldai_vars.newlinemode in ["s", "ns"]:
+
+                if (
+                    utils.koboldai_vars.newlinemode in ["s", "ns"]
+                    and not m_self.gen_state["allow_eos"]
+                ):
                     kwargs["eos_token_id"] = -1
                     kwargs.setdefault("pad_token_id", 2)
                 return new_sample.old_sample(self, *args, **kwargs)
@@ -605,3 +610,9 @@ class HFTorchInferenceModel(HFInferenceModel):
             self.breakmodel = False
             self.usegpu = False
         return
+
+    def get_supported_gen_modes(self) -> List[GenerationMode]:
+        # Supported because the torch sample patch only bans EOS when allow_eos is False.
+        return super().get_supported_gen_modes() + [
+            GenerationMode.UNTIL_EOS
+        ]
\ No newline at end of file
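
A minimal usage sketch of the new mode, for reviewers. It assumes generation
goes through an `InferenceModel.raw_generate`-style entry point that accepts a
`gen_mode` keyword on an already-loaded `model`; the call shape and the
`prompt`/`max_new` arguments are illustrative assumptions, not something this
patch adds or verifies:

    # Sketch only: `model`, `prompt`, and the raw_generate call shape are assumed.
    from modeling.inference_model import GenerationMode

    result = model.raw_generate(
        prompt,
        max_new=80,  # effectively ignored: UNTIL_EOS sets stop_at_genamt = False
        gen_mode=GenerationMode.UNTIL_EOS,  # unban EOS; stop once the model samples it
    )

Only backends whose `get_supported_gen_modes` advertises UNTIL_EOS (the HF
torch backend after this patch) honor the request; any other backend falls
back to GenerationMode.STANDARD and logs a warning.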