diff --git a/modeling/inference_models/hf.py b/modeling/inference_models/hf.py
index 63c0a40d..013590ef 100644
--- a/modeling/inference_models/hf.py
+++ b/modeling/inference_models/hf.py
@@ -18,20 +18,6 @@ class HFInferenceModel(InferenceModel):
         self.tokenizer = None

     def _post_load(self) -> None:
-        # Clean up tokens that cause issues
-        if (
-            utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
-            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
-        ):
-            utils.koboldai_vars.badwordsids = [
-                [v]
-                for k, v in self.tokenizer.get_vocab().items()
-                if any(c in str(k) for c in "[]")
-            ]
-
-        if utils.koboldai_vars.newlinemode == "n":
-            utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
-
         # These are model specific tokenizer overrides if a model has bad defaults
         if utils.koboldai_vars.model_type == "llama":
             self.tokenizer.decode_with_prefix_space = True
@@ -49,6 +35,20 @@ class HFInferenceModel(InferenceModel):
             # Handle but don't convert newlines if using Fairseq models that have newlines trained in them
             utils.koboldai_vars.newlinemode = "ns"

+        # Clean up tokens that cause issues
+        if (
+            utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
+            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
+        ):
+            utils.koboldai_vars.badwordsids = [
+                [v]
+                for k, v in self.tokenizer.get_vocab().items()
+                if any(c in str(k) for c in "[]")
+            ]
+
+        if utils.koboldai_vars.newlinemode == "n":
+            utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
+
         return super()._post_load()

     def get_local_model_path(
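
The move makes the badwords cleanup observe the final value of `newlinemode`: the model-specific overrides (e.g. the Fairseq branch setting it to `"ns"`) now run before the `newlinemode == "n"` check that appends the EOS token to `badwordsids`. Below is a minimal, self-contained sketch of that ordering dependency. The stub dictionary and `StubTokenizer` are hypothetical stand-ins for `utils.koboldai_vars` and the real tokenizer; only the control flow mirrors the diff.

```python
# Hypothetical stand-in for the model tokenizer; not the project's API.
class StubTokenizer:
    eos_token_id = 2

    def get_vocab(self):
        # Tokens containing "[" or "]" are the ones filtered into badwordsids.
        return {"hello": 0, "[END]": 1, "</s>": 2}


def post_load(vars, tokenizer):
    # Model-specific overrides run first and may replace newlinemode...
    if vars["model_type"] == "fairseq_lm":
        vars["newlinemode"] = "ns"

    # ...so the cleanup below sees the final newlinemode value,
    # matching the post-move ordering in the diff.
    if vars["model_type"] not in ("gpt2", "gpt_neo", "gptj"):
        vars["badwordsids"] = [
            [v]
            for k, v in tokenizer.get_vocab().items()
            if any(c in str(k) for c in "[]")
        ]

    if vars["newlinemode"] == "n":
        # Only reached when no override replaced mode "n" above.
        vars["badwordsids"].append([tokenizer.eos_token_id])


vars = {"model_type": "llama", "newlinemode": "n", "badwordsids": []}
post_load(vars, StubTokenizer())
print(vars["badwordsids"])  # [[1], [2]]: "[END]" plus the EOS token
```

With `model_type = "fairseq_lm"` the override flips `newlinemode` to `"ns"` before the check, so the EOS token is not appended; under the pre-move ordering the check would have read the stale `"n"` value instead.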