diff --git a/utils.py b/utils.py index 6acecf3a..22d35584 100644 --- a/utils.py +++ b/utils.py @@ -155,34 +155,27 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.") _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE sharded = False + headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)} + if use_auth_token: + headers["authorization"] = f"Bearer {use_auth_token}" + def is_cached(url): + try: + transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True) + except FileNotFoundError: + return False + return True while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file try: filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME except AttributeError: return url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror) - try: - transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent) - except transformers.file_utils.RepositoryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " - "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " - "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " - "login` and pass `use_auth_token=True`." - ) - except transformers.file_utils.RevisionNotFoundError: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " - "this model name. Check the model page at " - f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." - ) - except transformers.file_utils.EntryNotFoundError: - if sharded: - return - else: - sharded = True - else: + if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers): break + if sharded: + return + else: + sharded = True if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download filenames = [transformers.modeling_utils.WEIGHTS_NAME] else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it @@ -191,20 +184,11 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d map_data = json.load(f) filenames = set(map_data["weight_map"].values()) urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames] - def is_cached(url): - try: - transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True) - except FileNotFoundError: - return False - return True if not force_download: if all(is_cached(u) for u in urls): return elif local_files_only: raise FileNotFoundError("Cannot find the requested files in the cached path and outgoing traffic has been disabled. To enable model look-ups and downloads online, set 'local_files_only' to False.") - headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)} - if use_auth_token: - headers["authorization"] = f"Bearer {use_auth_token}" etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]] filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)] if force_download: @@ -215,6 +199,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode() with tempfile.NamedTemporaryFile("w+b", delete=True) as f: f.write(aria2_config) + f.flush() p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT) for line in p.stdout: print(line.decode(), end="", flush=True)