Don't use `aria2_hook()` if `force_download=True` is used

This commit is contained in:
Gnome Ann 2022-05-11 14:40:31 -04:00
parent 903d593ce4
commit 5732a8f15a
1 changed files with 5 additions and 3 deletions

View File

@ -1,4 +1,4 @@
from threading import Timer
from threading import Timer, local
import re
import shutil
import json
@ -141,7 +141,9 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
from huggingface_hub import HfFolder
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
return
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index"):
if local_files_only: # If local_files_only is true, we obviously don't need to download anything
return
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
return
if proxies:
print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
@ -179,7 +181,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent)
map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())