diff --git a/utils.py b/utils.py index 0eb1de28..6acecf3a 100644 --- a/utils.py +++ b/utils.py @@ -207,6 +207,11 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d headers["authorization"] = f"Bearer {use_auth_token}" etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]] filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)] + if force_download: + for n in filenames: + path = os.path.join(_cache_dir, n + ".json") + if os.path.exists(path): + os.remove(path) aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode() with tempfile.NamedTemporaryFile("w+b", delete=True) as f: f.write(aria2_config)