diff --git a/utils.py b/utils.py index 0e4299de..bc085412 100644 --- a/utils.py +++ b/utils.py @@ -232,6 +232,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d s = requests.Session() s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1))) bar = None + done = False secret = os.urandom(17).hex() try: with tempfile.NamedTemporaryFile("w+b", delete=False) as f: @@ -246,6 +247,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d bar.n = bar.total bar.close() p.terminate() + done = True break if bar is None: bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000) @@ -270,7 +272,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d except OSError: pass code = p.wait() - if code: + if not done and code: raise OSError(f"aria2 exited with exit code {code}") for u, t, n in zip(urls, etags, filenames): os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))