Merge branch 'main' into united

This commit is contained in:
Henk
2023-02-18 18:14:03 +01:00
3 changed files with 23 additions and 7 deletions

View File

@@ -286,7 +286,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
_revision = revision if revision is not None else huggingface_hub.constants.DEFAULT_REVISION
_revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token:
@@ -306,7 +306,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=_revision)
if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
@@ -320,7 +320,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=_revision) for n in filenames]
if not force_download:
urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]
if not urls:
@@ -485,6 +485,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
import transformers
import transformers.modeling_utils
from huggingface_hub import HfFolder
_revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
return
if local_files_only: # If local_files_only is true, we obviously don't need to download anything
@@ -519,7 +520,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=_revision)
if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
@@ -533,7 +534,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=_revision) for n in filenames]
if not force_download:
urls = [u for u in urls if not is_cached(u)]
if not urls:
@@ -580,7 +581,8 @@ def get_num_shards(filename):
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers.modeling_utils
import torch
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
_revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=_revision)
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
#==================================================================#