Fix the behaviour of `aria2_hook()` when using `force_download`

This commit is contained in:
Gnome Ann 2022-05-11 14:41:34 -04:00
parent 5732a8f15a
commit f60c7d8492
1 changed files with 3 additions and 0 deletions

View File

@ -198,6 +198,9 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
path = os.path.join(_cache_dir, n + ".json")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, n)
if os.path.exists(path):
os.remove(path)
aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode()
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
f.write(aria2_config)