From 5732a8f15a9af40f58c13a741b14cd72a0811d90 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Wed, 11 May 2022 14:40:31 -0400 Subject: [PATCH] Don't use `aria2_hook()` if `force_download=True` is used --- utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/utils.py b/utils.py index 38066ed0..7295c677 100644 --- a/utils.py +++ b/utils.py @@ -1,4 +1,4 @@ -from threading import Timer +from threading import Timer, local import re import shutil import json @@ -141,7 +141,9 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d from huggingface_hub import HfFolder if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed return - if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index"): + if local_files_only: # If local_files_only is true, we obviously don't need to download anything + return + if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path): return if proxies: print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.") @@ -179,7 +181,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download filenames = [transformers.modeling_utils.WEIGHTS_NAME] else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it - map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent) + map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent) with open(map_filename) as f: map_data = json.load(f) filenames = set(map_data["weight_map"].values())