From 933dbd634a2e61a556091d024a652039afee1c86 Mon Sep 17 00:00:00 2001
From: somebody
Date: Mon, 1 May 2023 17:13:33 -0500
Subject: [PATCH] HFInferenceModel: Make badwordsids not unique to torch

---
 KoboldAI-Horde-Bridge                         |  2 +-
 modeling/inference_models/generic_hf_torch.py | 10 ----------
 modeling/inference_models/hf.py               | 18 ++++++++++++++++++
 modeling/inference_models/hf_torch.py         |  2 ++
 4 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/KoboldAI-Horde-Bridge b/KoboldAI-Horde-Bridge
index d9014eba..7a732780 160000
--- a/KoboldAI-Horde-Bridge
+++ b/KoboldAI-Horde-Bridge
@@ -1 +1 @@
-Subproject commit d9014ebac969c0e5c37eb5456deebc3518130391
+Subproject commit 7a7327804ff10182adf8cda48e97784958699a49
diff --git a/modeling/inference_models/generic_hf_torch.py b/modeling/inference_models/generic_hf_torch.py
index ce91b176..aa602b1a 100644
--- a/modeling/inference_models/generic_hf_torch.py
+++ b/modeling/inference_models/generic_hf_torch.py
@@ -239,16 +239,6 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
             )
             shutil.rmtree("cache/")
 
-        if (
-            utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
-            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
-        ):
-            utils.koboldai_vars.badwordsids = [
-                [v]
-                for k, v in self.tokenizer.get_vocab().items()
-                if any(c in str(k) for c in "[]")
-            ]
-
         self.patch_embedding()
 
         if utils.koboldai_vars.hascuda:
diff --git a/modeling/inference_models/hf.py b/modeling/inference_models/hf.py
index eae4bb2d..eac5284f 100644
--- a/modeling/inference_models/hf.py
+++ b/modeling/inference_models/hf.py
@@ -3,6 +3,7 @@ from typing import Optional
 
 from transformers import AutoConfig
 
 import utils
+import koboldai_settings
 from logger import logger
 from modeling.inference_model import InferenceModel
 
@@ -16,6 +17,23 @@ class HFInferenceModel(InferenceModel):
         self.model = None
         self.tokenizer = None
 
+    def _post_load(self) -> None:
+        # Clean up tokens that cause issues
+        if (
+            utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
+            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
+        ):
+            utils.koboldai_vars.badwordsids = [
+                [v]
+                for k, v in self.tokenizer.get_vocab().items()
+                if any(c in str(k) for c in "[]")
+            ]
+
+        if utils.koboldai_vars.newlinemode == "n":
+            utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
+
+        return super()._post_load()
+
     def get_local_model_path(
         self, legacy: bool = False, ignore_existance: bool = False
     ) -> Optional[str]:
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 890a9e8e..dc445348 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -220,6 +220,8 @@ class HFTorchInferenceModel(HFInferenceModel):
         new_sample.old_sample = transformers.GenerationMixin.sample
         use_core_manipulations.sample = new_sample
 
+        return super()._post_load()
+
     def _raw_generate(
         self,
         prompt_tokens: Union[List[int], torch.Tensor],
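
The badwordsids derivation that this patch hoists from the torch backend into
HFInferenceModel._post_load is easiest to see in isolation. The sketch below is
a minimal standalone illustration, not KoboldAI code: derive_badwordsids and
toy_vocab are hypothetical names, and the toy dict stands in for the mapping
returned by self.tokenizer.get_vocab().

    from typing import Dict, List

    def derive_badwordsids(vocab: Dict[str, int]) -> List[List[int]]:
        # Ban every token whose text contains a square bracket; the
        # list-of-lists shape matches the bad-words format used by
        # transformers' generation utilities (one token id per entry).
        return [
            [token_id]
            for token, token_id in vocab.items()
            if any(c in str(token) for c in "[]")
        ]

    if __name__ == "__main__":
        toy_vocab = {"Hello": 0, "[": 1, "]": 2, "[WP]": 3, "world": 4}
        print(derive_badwordsids(toy_vocab))  # -> [[1], [2], [3]]

Because the hook now lives on HFInferenceModel, every Hugging Face backend
inherits the same banned-token setup rather than only the torch path, and
HFTorchInferenceModel._post_load chains up to it via super()._post_load().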