mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
More loading fixes
This commit is contained in:
@@ -270,13 +270,21 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
)
|
||||
|
||||
def _get_model(self, location: str, tf_kwargs: Dict):
|
||||
tf_kwargs["revision"] = utils.koboldai_vars.revision
|
||||
tf_kwargs["cache_dir"] = "cache"
|
||||
|
||||
# If we have model hints for legacy model, use them rather than fall back.
|
||||
try:
|
||||
return AutoModelForCausalLM.from_pretrained(
|
||||
location,
|
||||
revision=utils.koboldai_vars.revision,
|
||||
cache_dir="cache",
|
||||
**tf_kwargs,
|
||||
)
|
||||
if self.model_name == "GPT2Custom":
|
||||
return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
|
||||
elif self.model_name == "NeoCustom":
|
||||
return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
|
||||
except Exception as e:
|
||||
logger.warning(f"{self.model_name} is a no-go; {e} - Falling back to auto.")
|
||||
|
||||
# Try to determine model type from either AutoModel or falling back to legacy
|
||||
try:
|
||||
return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
|
||||
except Exception as e:
|
||||
if "out of memory" in traceback.format_exc().lower():
|
||||
raise RuntimeError(
|
||||
@@ -284,22 +292,11 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
)
|
||||
|
||||
logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
|
||||
|
||||
try:
|
||||
return GPT2LMHeadModel.from_pretrained(
|
||||
location,
|
||||
revision=utils.koboldai_vars.revision,
|
||||
cache_dir="cache",
|
||||
**tf_kwargs,
|
||||
)
|
||||
return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
|
||||
except Exception as e:
|
||||
logger.warning(f"Fell back to GPTNeoForCausalLM due to {e}")
|
||||
return GPTNeoForCausalLM.from_pretrained(
|
||||
location,
|
||||
revision=utils.koboldai_vars.revision,
|
||||
cache_dir="cache",
|
||||
**tf_kwargs,
|
||||
)
|
||||
return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
|
||||
|
||||
def get_hidden_size(self) -> int:
    """Return the width of the model's input embedding layer.

    Calls ``self.model.get_input_embeddings()`` and reads the
    ``embedding_dim`` attribute of the object it returns.

    Returns:
        The ``embedding_dim`` value of the input embedding layer.
    """
    return self.model.get_input_embeddings().embedding_dim
|
||||
|
Reference in New Issue
Block a user