Better Aria2 Defaults

Trunc prevents slow allocation on windows, force_download=True has proven a more reliable default. Since models are converted to local formats it does not impact local users. And because -c is used the impact of checking if the model is correct is desirable and minimal.
This commit is contained in:
Henk 2022-05-11 21:38:33 +02:00
parent 903d593ce4
commit 6d27084e8a
1 changed files with 2 additions and 2 deletions

View File

@ -135,7 +135,7 @@ def decodenewlines(txt):
#==================================================================# #==================================================================#
# Downloads sharded huggingface checkpoints using aria2c if possible # Downloads sharded huggingface checkpoints using aria2c if possible
#==================================================================# #==================================================================#
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs): def aria2_hook(pretrained_model_name_or_path: str, force_download=True, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
import transformers import transformers
import transformers.modeling_utils import transformers.modeling_utils
from huggingface_hub import HfFolder from huggingface_hub import HfFolder
@ -200,7 +200,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
with tempfile.NamedTemporaryFile("w+b", delete=False) as f: with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
f.write(aria2_config) f.write(aria2_config)
f.flush() f.flush()
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=none", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT) p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=trunc", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
for line in p.stdout: for line in p.stdout:
print(line.decode(), end="", flush=True) print(line.decode(), end="", flush=True)
path = f.name path = f.name