Download all shards correctly on aria2 and raise on bad load key

This commit is contained in:
somebody
2023-05-01 18:53:36 -05:00
parent 933dbd634a
commit 97e84928ba
2 changed files with 13 additions and 2 deletions

View File

@@ -288,11 +288,18 @@ class HFTorchInferenceModel(HFInferenceModel):
try:
return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
except Exception as e:
if "out of memory" in traceback.format_exc().lower():
traceback_string = traceback.format_exc().lower()
if "out of memory" in traceback_string:
raise RuntimeError(
"One of your GPUs ran out of memory when KoboldAI tried to load your model."
)
# Model corrupted or serious loading problem. Stop here.
if "invalid load key" in traceback_string:
logger.error("Invalid load key! Aborting.")
raise
logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
try:
return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)