Merge pull request #126 from VE-FORBRYDERNE/aria2

Aria2 downloader bug fixes
This commit is contained in:
henk717 2022-05-11 21:58:31 +02:00 committed by GitHub
commit 05549de42d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 10 deletions

View File

@ -135,13 +135,15 @@ def decodenewlines(txt):
#==================================================================# #==================================================================#
# Downloads sharded huggingface checkpoints using aria2c if possible # Downloads sharded huggingface checkpoints using aria2c if possible
#==================================================================# #==================================================================#
def aria2_hook(pretrained_model_name_or_path: str, force_download=True, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs): def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
import transformers import transformers
import transformers.modeling_utils import transformers.modeling_utils
from huggingface_hub import HfFolder from huggingface_hub import HfFolder
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
return return
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index"): if local_files_only: # If local_files_only is true, we obviously don't need to download anything
return
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
return return
if proxies: if proxies:
print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.") print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
@ -179,28 +181,36 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=True, cache_di
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME] filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent) map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
with open(map_filename) as f: with open(map_filename) as f:
map_data = json.load(f) map_data = json.load(f)
filenames = set(map_data["weight_map"].values()) filenames = set(map_data["weight_map"].values())
urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames] urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
if not force_download: if not force_download:
if all(is_cached(u) for u in urls): urls = [u for u in urls if not is_cached(u)]
if not urls:
return return
elif local_files_only:
raise FileNotFoundError("Cannot find the requested files in the cached path and outgoing traffic has been disabled. To enable model look-ups and downloads online, set 'local_files_only' to False.")
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]] etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)] filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
if force_download: for n in filenames:
for n in filenames: path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, "kai-tempfile." + n)
if os.path.exists(path):
os.remove(path)
if force_download:
path = os.path.join(_cache_dir, n + ".json") path = os.path.join(_cache_dir, n + ".json")
if os.path.exists(path): if os.path.exists(path):
os.remove(path) os.remove(path)
aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode() path = os.path.join(_cache_dir, n)
if os.path.exists(path):
os.remove(path)
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
with tempfile.NamedTemporaryFile("w+b", delete=False) as f: with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
f.write(aria2_config) f.write(aria2_config)
f.flush() f.flush()
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=trunc", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT) p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming", "false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
for line in p.stdout: for line in p.stdout:
print(line.decode(), end="", flush=True) print(line.decode(), end="", flush=True)
path = f.name path = f.name
@ -209,5 +219,6 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=True, cache_di
except OSError: except OSError:
pass pass
for u, t, n in zip(urls, etags, filenames): for u, t, n in zip(urls, etags, filenames):
os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
with open(os.path.join(_cache_dir, n + ".json"), "w") as f: with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
json.dump({"url": u, "etag": t}, f) json.dump({"url": u, "etag": t}, f)