More loading fixes

Author: onesome
Date:   2023-04-26 16:58:33 -05:00
Parent: d4f7b60dc9
Commit: 467f2f25eb


@@ -270,13 +270,21 @@ class HFTorchInferenceModel(HFInferenceModel):
         )
 
     def _get_model(self, location: str, tf_kwargs: Dict):
+        tf_kwargs["revision"] = utils.koboldai_vars.revision
+        tf_kwargs["cache_dir"] = "cache"
+
+        # If we have model hints for legacy model, use them rather than fall back.
         try:
-            return AutoModelForCausalLM.from_pretrained(
-                location,
-                revision=utils.koboldai_vars.revision,
-                cache_dir="cache",
-                **tf_kwargs,
-            )
+            if self.model_name == "GPT2Custom":
+                return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
+            elif self.model_name == "NeoCustom":
+                return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
+        except Exception as e:
+            logger.warning(f"{self.model_name} is a no-go; {e} - Falling back to auto.")
+
+        # Try to determine model type from either AutoModel or falling back to legacy
+        try:
+            return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
         except Exception as e:
             if "out of memory" in traceback.format_exc().lower():
                 raise RuntimeError(
@@ -284,22 +292,11 @@ class HFTorchInferenceModel(HFInferenceModel):
                 )
             logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
             try:
-                return GPT2LMHeadModel.from_pretrained(
-                    location,
-                    revision=utils.koboldai_vars.revision,
-                    cache_dir="cache",
-                    **tf_kwargs,
-                )
+                return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
             except Exception as e:
                 logger.warning(f"Fell back to GPTNeoForCausalLM due to {e}")
-                return GPTNeoForCausalLM.from_pretrained(
-                    location,
-                    revision=utils.koboldai_vars.revision,
-                    cache_dir="cache",
-                    **tf_kwargs,
-                )
+                return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
 
     def get_hidden_size(self) -> int:
         return self.model.get_input_embeddings().embedding_dim
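
For readers skimming the diff, the loading order after this change boils down to the pattern below. This is a rough, self-contained sketch rather than the KoboldAI method itself: load_model, model_name, location, and kwargs are placeholder names, logging is simplified, and the out-of-memory re-raise from the real method is omitted.

import logging
from typing import Dict

from transformers import (
    AutoModelForCausalLM,
    GPT2LMHeadModel,
    GPTNeoForCausalLM,
)

logger = logging.getLogger(__name__)


def load_model(model_name: str, location: str, kwargs: Dict):
    # 1) Legacy model hints: load the known class directly rather than guessing.
    try:
        if model_name == "GPT2Custom":
            return GPT2LMHeadModel.from_pretrained(location, **kwargs)
        elif model_name == "NeoCustom":
            return GPTNeoForCausalLM.from_pretrained(location, **kwargs)
    except Exception as e:
        logger.warning(f"{model_name} hint failed ({e}); falling back to auto.")

    # 2) Let AutoModel infer the architecture from the checkpoint's config.
    try:
        return AutoModelForCausalLM.from_pretrained(location, **kwargs)
    except Exception as e:
        logger.warning(f"AutoModel failed ({e}); trying legacy classes.")

    # 3) Last-resort legacy fallbacks, in the same order as the diff.
    try:
        return GPT2LMHeadModel.from_pretrained(location, **kwargs)
    except Exception:
        return GPTNeoForCausalLM.from_pretrained(location, **kwargs)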