diff --git a/utils.py b/utils.py index 22d35584..38066ed0 100644 --- a/utils.py +++ b/utils.py @@ -197,12 +197,17 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d if os.path.exists(path): os.remove(path) aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode() - with tempfile.NamedTemporaryFile("w+b", delete=True) as f: + with tempfile.NamedTemporaryFile("w+b", delete=False) as f: f.write(aria2_config) f.flush() - p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=none", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT) for line in p.stdout: print(line.decode(), end="", flush=True) + path = f.name + try: + os.remove(path) + except OSError: + pass for u, t, n in zip(urls, etags, filenames): with open(os.path.join(_cache_dir, n + ".json"), "w") as f: json.dump({"url": u, "etag": t}, f)