Revision Fixes Fixes
This commit is contained in:
parent
739cccd8ed
commit
257a535be5
2
utils.py
2
utils.py
|
@ -460,6 +460,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
|||
import transformers
|
||||
import transformers.modeling_utils
|
||||
from huggingface_hub import HfFolder
|
||||
_revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
|
||||
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
|
||||
return
|
||||
if local_files_only: # If local_files_only is true, we obviously don't need to download anything
|
||||
|
@ -555,6 +556,7 @@ def get_num_shards(filename):
|
|||
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
|
||||
import transformers.modeling_utils
|
||||
import torch
|
||||
_revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
|
||||
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=_revision)
|
||||
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
|
||||
|
||||
|
|
Loading…
Reference in New Issue