Fix the logic of `force_download` in utils.py
This commit is contained in:
parent
c1ef20bcff
commit
4b693b4858
5
utils.py
5
utils.py
|
@ -207,6 +207,11 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
||||||
headers["authorization"] = f"Bearer {use_auth_token}"
|
headers["authorization"] = f"Bearer {use_auth_token}"
|
||||||
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
|
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
|
||||||
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
|
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
|
||||||
|
if force_download:
|
||||||
|
for n in filenames:
|
||||||
|
path = os.path.join(_cache_dir, n + ".json")
|
||||||
|
if os.path.exists(path):
|
||||||
|
os.remove(path)
|
||||||
aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode()
|
aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode()
|
||||||
with tempfile.NamedTemporaryFile("w+b", delete=True) as f:
|
with tempfile.NamedTemporaryFile("w+b", delete=True) as f:
|
||||||
f.write(aria2_config)
|
f.write(aria2_config)
|
||||||
|
|
Loading…
Reference in New Issue