Aria2 bug fix for Windows users

This commit is contained in:
Gnome Ann 2022-05-14 11:44:28 -04:00
parent 1476e76cfc
commit 6e82f205b4
1 changed files with 3 additions and 1 deletions

View File

@ -232,6 +232,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
s = requests.Session() s = requests.Session()
s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1))) s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
bar = None bar = None
done = False
secret = os.urandom(17).hex() secret = os.urandom(17).hex()
try: try:
with tempfile.NamedTemporaryFile("w+b", delete=False) as f: with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
@ -246,6 +247,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
bar.n = bar.total bar.n = bar.total
bar.close() bar.close()
p.terminate() p.terminate()
done = True
break break
if bar is None: if bar is None:
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000) bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
@ -270,7 +272,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
except OSError: except OSError:
pass pass
code = p.wait() code = p.wait()
if code: if not done and code:
raise OSError(f"aria2 exited with exit code {code}") raise OSError(f"aria2 exited with exit code {code}")
for u, t, n in zip(urls, etags, filenames): for u, t, n in zip(urls, etags, filenames):
os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n)) os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))