Revision Fixes Fixes

This commit is contained in:
Henk 2023-01-31 05:17:34 +01:00
parent 739cccd8ed
commit 257a535be5
1 changed files with 2 additions and 0 deletions

View File

@ -460,6 +460,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
import transformers
import transformers.modeling_utils
from huggingface_hub import HfFolder
_revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
return
if local_files_only: # If local_files_only is true, we obviously don't need to download anything
@ -555,6 +556,7 @@ def get_num_shards(filename):
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers.modeling_utils
import torch
_revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=_revision)
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))