mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-13 10:10:54 +01:00
Also enable aria2 downloading for non-sharded checkpoints
This commit is contained in:
parent
e115bb68e4
commit
c1ef20bcff
@ -186,7 +186,7 @@ fi
|
|||||||
|
|
||||||
#Download routine for Aria2c scripts
|
#Download routine for Aria2c scripts
|
||||||
if [ ! -z ${aria2+x} ]; then
|
if [ ! -z ${aria2+x} ]; then
|
||||||
curl -L $aria2 | aria2c -c -i- -d$dloc --user-agent=KoboldAI --file-allocation=none
|
curl -L $aria2 | aria2c -x 10 -s 10 -j 10 -c -i- -d$dloc --user-agent=KoboldAI --file-allocation=none
|
||||||
fi
|
fi
|
||||||
|
|
||||||
#Extract the model with 7z
|
#Extract the model with 7z
|
||||||
|
8
utils.py
8
utils.py
@ -183,9 +183,9 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
|||||||
sharded = True
|
sharded = True
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
if not sharded: # If the model has a pytorch_model.bin file, that's the only large file to download so it's probably more efficient to just let transformers download it
|
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
|
||||||
return
|
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
|
||||||
# Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
|
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
|
||||||
map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent)
|
map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent)
|
||||||
with open(map_filename) as f:
|
with open(map_filename) as f:
|
||||||
map_data = json.load(f)
|
map_data = json.load(f)
|
||||||
@ -210,7 +210,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
|||||||
aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode()
|
aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode()
|
||||||
with tempfile.NamedTemporaryFile("w+b", delete=True) as f:
|
with tempfile.NamedTemporaryFile("w+b", delete=True) as f:
|
||||||
f.write(aria2_config)
|
f.write(aria2_config)
|
||||||
p = subprocess.Popen(["aria2c", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||||
for line in p.stdout:
|
for line in p.stdout:
|
||||||
print(line.decode(), end="", flush=True)
|
print(line.decode(), end="", flush=True)
|
||||||
for u, t, n in zip(urls, etags, filenames):
|
for u, t, n in zip(urls, etags, filenames):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user