aria2 now downloads to a different filename and renames it afterwards
This matches the behaviour of the original transformers downloader, in order to handle the rare case where someone starts downloading a model with aria2, cancels before it finishes, and then attempts to resume the download with the normal transformers downloader.
This commit is contained in:
parent
7a3f865e3f
commit
c65272052a
15
utils.py
15
utils.py
|
@ -1,4 +1,4 @@
|
|||
from threading import Timer, local
|
||||
from threading import Timer
|
||||
import re
|
||||
import shutil
|
||||
import json
|
||||
|
@ -186,13 +186,19 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
|||
map_data = json.load(f)
|
||||
filenames = set(map_data["weight_map"].values())
|
||||
urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
|
||||
if not force_download:
|
||||
urls = [u for u in urls if not is_cached(u)]
|
||||
if not urls:
|
||||
return
|
||||
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
|
||||
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
|
||||
for n in filenames:
|
||||
path = os.path.join(_cache_dir, n + ".aria2") # Prevent aria2 from continuing cancelled downloads because continued downloads are usually limited to the speed of 1 connection
|
||||
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
path = os.path.join(_cache_dir, "kai-tempfile." + n)
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
os.remove(os.path.join(_cache_dir, n))
|
||||
if force_download:
|
||||
path = os.path.join(_cache_dir, n + ".json")
|
||||
if os.path.exists(path):
|
||||
|
@ -200,7 +206,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
|||
path = os.path.join(_cache_dir, n)
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode()
|
||||
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
|
||||
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
|
||||
f.write(aria2_config)
|
||||
f.flush()
|
||||
|
@ -213,5 +219,6 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
|||
except OSError:
|
||||
pass
|
||||
for u, t, n in zip(urls, etags, filenames):
|
||||
os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
|
||||
with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
|
||||
json.dump({"url": u, "etag": t}, f)
|
||||
|
|
Loading…
Reference in New Issue