Better Aria2 Defaults
Trunc prevents slow allocation on windows, force_download=True has proven a more reliable default. Since models are converted to local formats it does not impact local users. And because -c is used the impact of checking if the model is correct is desirable and minimal.
This commit is contained in:
parent
903d593ce4
commit
6d27084e8a
4
utils.py
4
utils.py
|
@ -135,7 +135,7 @@ def decodenewlines(txt):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Downloads sharded huggingface checkpoints using aria2c if possible
|
# Downloads sharded huggingface checkpoints using aria2c if possible
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
|
def aria2_hook(pretrained_model_name_or_path: str, force_download=True, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
|
||||||
import transformers
|
import transformers
|
||||||
import transformers.modeling_utils
|
import transformers.modeling_utils
|
||||||
from huggingface_hub import HfFolder
|
from huggingface_hub import HfFolder
|
||||||
|
@ -200,7 +200,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
||||||
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
|
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
|
||||||
f.write(aria2_config)
|
f.write(aria2_config)
|
||||||
f.flush()
|
f.flush()
|
||||||
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=none", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=trunc", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||||
for line in p.stdout:
|
for line in p.stdout:
|
||||||
print(line.decode(), end="", flush=True)
|
print(line.decode(), end="", flush=True)
|
||||||
path = f.name
|
path = f.name
|
||||||
|
|
Loading…
Reference in New Issue