Fix the logic of `force_download` in utils.py

This commit is contained in:
Gnome Ann 2022-05-10 22:47:03 -04:00
parent c1ef20bcff
commit 4b693b4858
1 changed files with 5 additions and 0 deletions

View File

@ -207,6 +207,11 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
headers["authorization"] = f"Bearer {use_auth_token}" headers["authorization"] = f"Bearer {use_auth_token}"
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]] etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)] filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
if force_download:
for n in filenames:
path = os.path.join(_cache_dir, n + ".json")
if os.path.exists(path):
os.remove(path)
aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode() aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode()
with tempfile.NamedTemporaryFile("w+b", delete=True) as f: with tempfile.NamedTemporaryFile("w+b", delete=True) as f:
f.write(aria2_config) f.write(aria2_config)