From 257a535be59a6ebc57eabf26089bbfdd85310994 Mon Sep 17 00:00:00 2001 From: Henk Date: Tue, 31 Jan 2023 05:17:34 +0100 Subject: [PATCH] Revision Fixes Fixes --- utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/utils.py b/utils.py index 79f90b11..01c8b2a3 100644 --- a/utils.py +++ b/utils.py @@ -460,6 +460,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d import transformers import transformers.modeling_utils from huggingface_hub import HfFolder + _revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed return if local_files_only: # If local_files_only is true, we obviously don't need to download anything @@ -555,6 +556,7 @@ def get_num_shards(filename): def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs): import transformers.modeling_utils import torch + _revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=_revision) return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))