From 467f2f25eb489b33b616a71d0b3566f3fd4ce1b5 Mon Sep 17 00:00:00 2001
From: onesome
Date: Wed, 26 Apr 2023 16:58:33 -0500
Subject: [PATCH] More loading fixes

---
 modeling/inference_models/hf_torch.py | 35 ++++++++++++---------------
 1 file changed, 16 insertions(+), 19 deletions(-)

diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 16186872..890a9e8e 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -270,13 +270,21 @@ class HFTorchInferenceModel(HFInferenceModel):
         )
 
     def _get_model(self, location: str, tf_kwargs: Dict):
+        tf_kwargs["revision"] = utils.koboldai_vars.revision
+        tf_kwargs["cache_dir"] = "cache"
+
+        # If we have model hints for a legacy model, use them rather than falling back.
         try:
-            return AutoModelForCausalLM.from_pretrained(
-                location,
-                revision=utils.koboldai_vars.revision,
-                cache_dir="cache",
-                **tf_kwargs,
-            )
+            if self.model_name == "GPT2Custom":
+                return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
+            elif self.model_name == "NeoCustom":
+                return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
+        except Exception as e:
+            logger.warning(f"{self.model_name} is a no-go; {e} - Falling back to auto.")
+
+        # Try to determine the model type via AutoModel before falling back to the legacy classes.
+        try:
+            return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
         except Exception as e:
             if "out of memory" in traceback.format_exc().lower():
                 raise RuntimeError(
@@ -284,22 +292,11 @@ class HFTorchInferenceModel(HFInferenceModel):
                 )
 
             logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
-
             try:
-                return GPT2LMHeadModel.from_pretrained(
-                    location,
-                    revision=utils.koboldai_vars.revision,
-                    cache_dir="cache",
-                    **tf_kwargs,
-                )
+                return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
             except Exception as e:
                 logger.warning(f"Fell back to GPTNeoForCausalLM due to {e}")
-                return GPTNeoForCausalLM.from_pretrained(
-                    location,
-                    revision=utils.koboldai_vars.revision,
-                    cache_dir="cache",
-                    **tf_kwargs,
-                )
+                return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
 
     def get_hidden_size(self) -> int:
         return self.model.get_input_embeddings().embedding_dim
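
After this patch, _get_model resolves a model in three stages: legacy model-name hints first ("GPT2Custom" / "NeoCustom"), then AutoModelForCausalLM, then GPT2LMHeadModel and GPTNeoForCausalLM as last resorts, with revision and cache_dir folded into tf_kwargs once instead of being repeated at every from_pretrained call site. A minimal standalone sketch of that chain, assuming only the transformers library (the function name, signature, and print-based logging below are illustrative, not part of the patch):

import traceback
from typing import Dict

from transformers import AutoModelForCausalLM, GPT2LMHeadModel, GPTNeoForCausalLM


def load_with_fallbacks(location: str, model_name: str, tf_kwargs: Dict):
    # Stage 1: legacy hints. A known custom model type skips AutoModel
    # entirely; a failed or non-matching hint falls through to the auto path.
    try:
        if model_name == "GPT2Custom":
            return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
        elif model_name == "NeoCustom":
            return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
    except Exception as e:
        print(f"{model_name} hint failed ({e}); falling back to auto.")

    # Stage 2: AutoModel, with the legacy classes as last resorts. An OOM
    # failure is re-raised rather than masked by a futile retry.
    try:
        return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
    except Exception:
        if "out of memory" in traceback.format_exc().lower():
            raise
        try:
            return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
        except Exception:
            return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)


# Example (hypothetical arguments):
# model = load_with_fallbacks("gpt2", "unknown", {"cache_dir": "cache"})

Note that when model_name matches neither hint, the first try body simply falls through without returning, so the AutoModel path still runs; wrapping the hint block in its own try/except means a corrupt legacy model also degrades gracefully instead of aborting the load.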