mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Download all shards correctly on aria2 and raise on bad load key
This commit is contained in:
@@ -288,11 +288,18 @@ class HFTorchInferenceModel(HFInferenceModel):
         try:
             return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
         except Exception as e:
-            if "out of memory" in traceback.format_exc().lower():
+            traceback_string = traceback.format_exc().lower()
+
+            if "out of memory" in traceback_string:
                 raise RuntimeError(
                     "One of your GPUs ran out of memory when KoboldAI tried to load your model."
                 )
+
+            # Model corrupted or serious loading problem. Stop here.
+            if "invalid load key" in traceback_string:
+                logger.error("Invalid load key! Aborting.")
+                raise
+
             logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
             try:
                 return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
Reference in New Issue
Block a user