Download all shards correctly on aria2 and raise on bad load key

This commit is contained in:
somebody
2023-05-01 18:53:36 -05:00
parent 933dbd634a
commit 97e84928ba
2 changed files with 13 additions and 2 deletions

View File

@@ -285,6 +285,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
# one of these out in the wild yet, probably due to how Safetensors has a
# lot of benifits of sharding built in
for possible_filename in [
transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME,
transformers.modeling_utils.SAFE_WEIGHTS_NAME,
transformers.modeling_utils.WEIGHTS_INDEX_NAME,
transformers.modeling_utils.WEIGHTS_NAME
@@ -299,7 +300,10 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
if not filename:
return
if filename not in [transformers.modeling_utils.WEIGHTS_INDEX_NAME]:
if filename not in [
transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME,
transformers.modeling_utils.WEIGHTS_INDEX_NAME
]:
# If the model isn't sharded, theres only one file to download
filenames = [filename]
else: