Download all shards correctly via aria2 and raise on bad load key

somebody
2023-05-01 18:53:36 -05:00
parent 933dbd634a
commit 97e84928ba
2 changed files with 13 additions and 2 deletions


@@ -288,11 +288,18 @@ class HFTorchInferenceModel(HFInferenceModel):
         try:
             return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
         except Exception as e:
-            if "out of memory" in traceback.format_exc().lower():
+            traceback_string = traceback.format_exc().lower()
+            if "out of memory" in traceback_string:
                 raise RuntimeError(
                     "One of your GPUs ran out of memory when KoboldAI tried to load your model."
                 )
+
+            # Model corrupted or serious loading problem. Stop here.
+            if "invalid load key" in traceback_string:
+                logger.error("Invalid load key! Aborting.")
+                raise
+
             logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
             try:
                 return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
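
The "invalid load key" string comes from Python's pickle module, which torch.load uses under the hood: a truncated or corrupt pytorch_model.bin fails with "invalid load key, '<byte>'", so silently retrying with GPT2LMHeadModel would only re-read the same broken file. A minimal standalone sketch of the post-patch fallback logic (load_model is a hypothetical free function; the real code is a method of HFTorchInferenceModel, and logger/tf_kwargs stand in for its attributes):

import logging
import traceback

from transformers import AutoModelForCausalLM, GPT2LMHeadModel

logger = logging.getLogger(__name__)  # stands in for the project's logger


def load_model(location, **tf_kwargs):
    """Hypothetical standalone version of the patched fallback."""
    try:
        return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
    except Exception as e:
        traceback_string = traceback.format_exc().lower()
        if "out of memory" in traceback_string:
            # A CUDA OOM is not recoverable by switching model classes.
            raise RuntimeError(
                "One of your GPUs ran out of memory when KoboldAI tried to load your model."
            )
        if "invalid load key" in traceback_string:
            # pickle raises this for corrupt/truncated checkpoints; a GPT-2
            # fallback would load the same broken file, so abort instead.
            logger.error("Invalid load key! Aborting.")
            raise
        # Anything else: fall back to the legacy GPT-2 loader.
        logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
        return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)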


@@ -285,6 +285,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
     # one of these out in the wild yet, probably due to how Safetensors has a
     # lot of benefits of sharding built in
     for possible_filename in [
+        transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME,
         transformers.modeling_utils.SAFE_WEIGHTS_NAME,
         transformers.modeling_utils.WEIGHTS_INDEX_NAME,
         transformers.modeling_utils.WEIGHTS_NAME
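
Adding SAFE_WEIGHTS_INDEX_NAME to the probe list is what makes sharded safetensors repos work: such a repo ships only the index file plus numbered shards, with no single model.safetensors, so without the index name the probe would fall through to the PyTorch filenames. For reference, the constants resolve to these filenames (values checked against transformers 4.x, which re-exports them through transformers.modeling_utils as the hook assumes):

# Concrete values of the probed constants in transformers 4.x:
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,  # "model.safetensors.index.json" (sharded safetensors)
    SAFE_WEIGHTS_NAME,        # "model.safetensors"            (single-file safetensors)
    WEIGHTS_INDEX_NAME,       # "pytorch_model.bin.index.json" (sharded PyTorch)
    WEIGHTS_NAME,             # "pytorch_model.bin"            (single-file PyTorch)
)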
@@ -299,7 +300,10 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
     if not filename:
         return
-    if filename not in [transformers.modeling_utils.WEIGHTS_INDEX_NAME]:
+    if filename not in [
+        transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME,
+        transformers.modeling_utils.WEIGHTS_INDEX_NAME
+    ]:
         # If the model isn't sharded, there's only one file to download
         filenames = [filename]
     else:
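
With the safetensors index now recognized, the else: branch (its body is cut off above) expands the index into the full shard list so aria2 fetches every piece instead of just the index file. A sketch of that expansion under one assumption: the index JSON has already been downloaded to index_path (shard_filenames_from_index is a hypothetical helper; the "weight_map" layout is the standard Hugging Face sharded-checkpoint format):

import json


def shard_filenames_from_index(index_path):
    """Return the unique shard files referenced by a HF weight index
    such as model.safetensors.index.json or pytorch_model.bin.index.json."""
    with open(index_path) as f:
        index = json.load(f)
    # weight_map maps each tensor name to the shard file holding it;
    # many tensors share a shard, so deduplicate before downloading.
    return sorted(set(index["weight_map"].values()))

# Example: a two-shard safetensors model yields
# ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]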