Prevent aria2 from resuming cancelled downloads

Resumed downloads tend to be very slow.

The original transformers downloader didn't allow resuming downloads
either.
This commit is contained in:
Gnome Ann 2022-05-11 15:14:37 -04:00
parent c81f3bd084
commit 7a3f865e3f
1 changed files with 7 additions and 3 deletions

View File

@ -188,8 +188,12 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
if force_download:
for n in filenames:
for n in filenames:
path = os.path.join(_cache_dir, n + ".aria2") # Prevent aria2 from continuing cancelled downloads because continued downloads are usually limited to the speed of 1 connection
if os.path.exists(path):
os.remove(path)
os.remove(os.path.join(_cache_dir, n))
if force_download:
path = os.path.join(_cache_dir, n + ".json")
if os.path.exists(path):
os.remove(path)
@ -200,7 +204,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
f.write(aria2_config)
f.flush()
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=trunc", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming", "false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
for line in p.stdout:
print(line.decode(), end="", flush=True)
path = f.name