Revision Cleanup

This commit is contained in:
Henk
2023-01-31 18:46:59 +01:00
parent f5666d996f
commit f57489f73c

View File

@@ -286,7 +286,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
_revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
_revision = revision if revision is not None else huggingface_hub.constants.DEFAULT_REVISION
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token:
@@ -297,7 +297,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
# Report whether `filename` for the requested model is already available locally.
# Looks the file up through huggingface_hub.hf_hub_download with
# local_files_only=True, using the model name and cache_dir of the enclosing hook.
# Returns True when the lookup raises no ValueError, False when it does.
def is_cached(filename):
try:
# NOTE(review): this call and the one after it differ only in the revision
# argument (`_revision` vs `revision`). They look like the removed and added
# sides of a diff line whose +/- markers were lost -- confirm against the
# repository before treating both as live code.
huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, local_files_only=True, revision=_revision)
huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, local_files_only=True, revision=revision)
except ValueError:
# A ValueError from the local-only lookup is treated as "not cached".
return False
return True
@@ -306,7 +306,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=_revision)
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
@@ -316,11 +316,11 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=_revision) for n in filenames]
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
if not force_download:
urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]
if not urls: