diff --git a/modeling/inference_models/hf.py b/modeling/inference_models/hf.py
index b50ebf56..2417bffb 100644
--- a/modeling/inference_models/hf.py
+++ b/modeling/inference_models/hf.py
@@ -199,8 +199,9 @@ class HFInferenceModel(InferenceModel):
         pass
 
     def _post_load(self) -> None:
+        self.model_type = str(self.model_config.model_type)
         # These are model specific tokenizer overrides if a model has bad defaults
-        if utils.koboldai_vars.model_type == "llama":
+        if self.model_type == "llama":
             # Note: self.tokenizer is a GenericTokenizer, and self.tokenizer.tokenizer is the actual LlamaTokenizer
             self.tokenizer.add_bos_token = False
@@ -284,23 +285,23 @@ class HFInferenceModel(InferenceModel):
                     return result
                 object.__setattr__(self.tokenizer, '__call__', call_wrapper.__get__(self.tokenizer))
-        elif utils.koboldai_vars.model_type == "opt":
+        elif self.model_type == "opt":
             self.tokenizer._koboldai_header = self.tokenizer.encode("")
             self.tokenizer.add_bos_token = False
             self.tokenizer.add_prefix_space = False
 
         # Change newline behavior to match model quirks
-        if utils.koboldai_vars.model_type == "xglm":
+        if self.model_type == "xglm":
             # Default to newline mode if using XGLM
             utils.koboldai_vars.newlinemode = "s"
-        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
+        elif self.model_type in ["opt", "bloom"]:
             # Handle but don't convert newlines if using Fairseq models that have newlines trained in them
             utils.koboldai_vars.newlinemode = "ns"
 
         # Clean up tokens that cause issues
         if (
             utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
-            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
+            and self.model_type not in ("gpt2", "gpt_neo", "gptj")
         ):
             utils.koboldai_vars.badwordsids = [
                 [v]
@@ -357,15 +358,15 @@ class HFInferenceModel(InferenceModel):
                 revision=utils.koboldai_vars.revision,
                 cache_dir="cache",
             )
-            utils.koboldai_vars.model_type = self.model_config.model_type
+            self.model_type = self.model_config.model_type
         except ValueError:
-            utils.koboldai_vars.model_type = {
+            self.model_type = {
                 "NeoCustom": "gpt_neo",
                 "GPT2Custom": "gpt2",
-            }.get(utils.koboldai_vars.model)
+            }.get(self.model)
 
-        if not utils.koboldai_vars.model_type:
+        if not self.model_type:
             logger.warning(
                 "No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)"
             )
-            utils.koboldai_vars.model_type = "gpt_neo"
\ No newline at end of file
+            self.model_type = "gpt_neo"
\ No newline at end of file
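
For reference, the detection flow that the last hunk migrates from utils.koboldai_vars onto the instance can be summarized as a standalone helper. This is a minimal sketch assuming only the Hugging Face transformers AutoConfig API; detect_model_type and its parameters are illustrative stand-ins, not names from the KoboldAI codebase:

from transformers import AutoConfig

def detect_model_type(model_path: str, model_id: str) -> str:
    try:
        # AutoConfig exposes the architecture family as config.model_type;
        # str() guards against non-string values, as _post_load now does
        config = AutoConfig.from_pretrained(model_path, cache_dir="cache")
        return str(config.model_type)
    except ValueError:
        # Legacy custom-model menu ids ship no config, so map them directly,
        # defaulting to gpt_neo as the final hunk does
        return {"NeoCustom": "gpt_neo", "GPT2Custom": "gpt2"}.get(model_id, "gpt_neo")

Note that .get(model_id, "gpt_neo") collapses the dict lookup and the `if not self.model_type` default into one step; the patch keeps them separate so it can emit a warning before assuming Neo.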