aria2 now downloads to a different filename and renames the file afterwards

This is to match the behaviour of the original transformers downloader
in order to deal with the rare case of someone downloading a model using
aria2, cancelling before it finishes, and then attempting to resume the
download with the normal transformers downloader.
This commit is contained in:
Gnome Ann 2022-05-11 15:45:38 -04:00
parent 7a3f865e3f
commit c65272052a
1 changed files with 11 additions and 4 deletions

View File

@ -1,4 +1,4 @@
from threading import Timer, local
from threading import Timer
import re
import shutil
import json
@ -186,13 +186,19 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
if not force_download:
urls = [u for u in urls if not is_cached(u)]
if not urls:
return
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
for n in filenames:
path = os.path.join(_cache_dir, n + ".aria2") # Prevent aria2 from continuing cancelled downloads because continued downloads are usually limited to the speed of 1 connection
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, "kai-tempfile." + n)
if os.path.exists(path):
os.remove(path)
os.remove(os.path.join(_cache_dir, n))
if force_download:
path = os.path.join(_cache_dir, n + ".json")
if os.path.exists(path):
@ -200,7 +206,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
path = os.path.join(_cache_dir, n)
if os.path.exists(path):
os.remove(path)
aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode()
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
f.write(aria2_config)
f.flush()
@ -213,5 +219,6 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
except OSError:
pass
for u, t, n in zip(urls, etags, filenames):
os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
json.dump({"url": u, "etag": t}, f)