mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
More loading fixes
@@ -270,13 +270,21 @@ class HFTorchInferenceModel(HFInferenceModel):
         )
 
     def _get_model(self, location: str, tf_kwargs: Dict):
+        tf_kwargs["revision"] = utils.koboldai_vars.revision
+        tf_kwargs["cache_dir"] = "cache"
+
+        # If we have model hints for legacy model, use them rather than fall back.
         try:
-            return AutoModelForCausalLM.from_pretrained(
-                location,
-                revision=utils.koboldai_vars.revision,
-                cache_dir="cache",
-                **tf_kwargs,
-            )
+            if self.model_name == "GPT2Custom":
+                return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
+            elif self.model_name == "NeoCustom":
+                return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
+        except Exception as e:
+            logger.warning(f"{self.model_name} is a no-go; {e} - Falling back to auto.")
+
+        # Try to determine model type from either AutoModel or falling back to legacy
+        try:
+            return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
         except Exception as e:
             if "out of memory" in traceback.format_exc().lower():
                 raise RuntimeError(
@@ -284,22 +292,11 @@ class HFTorchInferenceModel(HFInferenceModel):
                 )
 
             logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
-
             try:
-                return GPT2LMHeadModel.from_pretrained(
-                    location,
-                    revision=utils.koboldai_vars.revision,
-                    cache_dir="cache",
-                    **tf_kwargs,
-                )
+                return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
             except Exception as e:
                 logger.warning(f"Fell back to GPTNeoForCausalLM due to {e}")
-                return GPTNeoForCausalLM.from_pretrained(
-                    location,
-                    revision=utils.koboldai_vars.revision,
-                    cache_dir="cache",
-                    **tf_kwargs,
-                )
+                return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
 
     def get_hidden_size(self) -> int:
         return self.model.get_input_embeddings().embedding_dim
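Read end to end, the new _get_model boils down to a three-stage loading ladder: honor an explicit legacy model-name hint, let AutoModelForCausalLM detect the architecture from the checkpoint, and only then try the legacy classes blindly. Below is a minimal standalone sketch of that ladder; load_with_fallbacks and LEGACY_LOADERS are illustrative names rather than KoboldAI identifiers, the KoboldAI-specific revision kwarg and the out-of-memory re-raise are omitted, and plain print stands in for logger.warning. The transformers classes and from_pretrained calls are the ones the diff itself uses.

from typing import Dict

from transformers import AutoModelForCausalLM, GPT2LMHeadModel, GPTNeoForCausalLM

# Hypothetical hint table mapping legacy model names to their loader classes.
LEGACY_LOADERS = {
    "GPT2Custom": GPT2LMHeadModel,
    "NeoCustom": GPTNeoForCausalLM,
}


def load_with_fallbacks(model_name: str, location: str, tf_kwargs: Dict):
    """Try the hinted legacy loader, then AutoModel, then the legacy classes."""
    # Shared kwargs are set once up front, the cleanup the second hunk performs.
    tf_kwargs = dict(tf_kwargs, cache_dir="cache")

    # Stage 1: honor an explicit legacy hint before guessing.
    loader = LEGACY_LOADERS.get(model_name)
    if loader is not None:
        try:
            return loader.from_pretrained(location, **tf_kwargs)
        except Exception as e:
            print(f"{model_name} is a no-go; {e} - Falling back to auto.")

    # Stage 2: let transformers detect the architecture from the config.
    try:
        return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
    except Exception as e:
        print(f"Fell back to GPT2LMHeadModel due to {e}")

    # Stage 3: last resort, try the legacy classes directly.
    try:
        return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
    except Exception as e:
        print(f"Fell back to GPTNeoForCausalLM due to {e}")
        return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)

With this shape, a call like load_with_fallbacks("NeoCustom", "models/my-neo", {}) (paths illustrative) loads via GPTNeoForCausalLM directly and never reaches the auto-detection path, which is exactly the "use hints rather than fall back" behavior the commit's first comment describes.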