mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-04-02 20:51:37 +02:00
Don't use aria2_hook()
if force_download=True
is used
This commit is contained in:
parent
903d593ce4
commit
5732a8f15a
8
utils.py
8
utils.py
@ -1,4 +1,4 @@
|
||||
from threading import Timer
|
||||
from threading import Timer, local
|
||||
import re
|
||||
import shutil
|
||||
import json
|
||||
@ -141,7 +141,9 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
||||
from huggingface_hub import HfFolder
|
||||
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
|
||||
return
|
||||
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
if local_files_only: # If local_files_only is true, we obviously don't need to download anything
|
||||
return
|
||||
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
|
||||
return
|
||||
if proxies:
|
||||
print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
|
||||
@ -179,7 +181,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
||||
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
|
||||
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
|
||||
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
|
||||
map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent)
|
||||
map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
|
||||
with open(map_filename) as f:
|
||||
map_data = json.load(f)
|
||||
filenames = set(map_data["weight_map"].values())
|
||||
|
Loading…
x
Reference in New Issue
Block a user