mirror of https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Use aria2 to download split checkpoints
aiserver.py (14 lines changed)
@@ -1111,6 +1111,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     import transformers.generation_utils
     from transformers import __version__ as transformers_version

+    from transformers import PreTrainedModel
+    old_from_pretrained = PreTrainedModel.from_pretrained
+    def new_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
+        utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
+        return old_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+    PreTrainedModel.from_pretrained = new_from_pretrained
+
     # Lazy loader
     import torch_lazy_loader
     def get_lazy_load_callback(n_layers, convert_to_float16=True):
@@ -1535,6 +1542,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     from transformers import GPT2TokenizerFast
     tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
 else:
+    from transformers import PreTrainedModel
+    old_from_pretrained = PreTrainedModel.from_pretrained
+    def new_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
+        utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
+        return old_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+    PreTrainedModel.from_pretrained = new_from_pretrained
+
     def tpumtjgetsofttokens():
         soft_tokens = None
         if(vars.sp is None):
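Both hunks install the same wrapper: save the original from_pretrained, run utils.aria2_hook first so the checkpoint files land in the transformers cache, then delegate the call unchanged. As a standalone illustration of the pattern (not code from this commit; my_download_hook is a hypothetical stand-in for utils.aria2_hook), one way to write such a wrapper that also preserves classmethod binding for subclasses:

    from transformers import PreTrainedModel

    def my_download_hook(pretrained_model_name_or_path, **kwargs):
        pass  # e.g. pre-download checkpoint files into the cache

    # Grab the underlying function so the wrapper can forward the real cls.
    _old_from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def _new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        my_download_hook(pretrained_model_name_or_path, **kwargs)
        # Delegate to the original; the files are cached by now, so
        # transformers loads them without downloading anything itself.
        return _old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)

    PreTrainedModel.from_pretrained = _new_from_pretrained

Because the hook only pre-populates the cache and forwards every argument untouched, the patch is transparent to all from_pretrained call sites.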
colabkobold.sh (3 lines changed)

@@ -162,7 +162,7 @@ if [ "$init" != "skip" ]; then
 fi

 # Make sure Colab has the system dependencies
-sudo apt install netbase -y
+sudo apt install netbase aria2 -y
 npm install -g localtunnel
 fi

@@ -186,7 +186,6 @@ fi

 #Download routine for Aria2c scripts
 if [ ! -z ${aria2+x} ]; then
-apt install aria2 -y
 curl -L $aria2 | aria2c -c -i- -d$dloc --user-agent=KoboldAI --file-allocation=none
 fi

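The download routine pipes a list fetched with curl straight into aria2c: with "-i -", aria2c reads its input file from stdin, where each URI line may be followed by indented option=value lines (out= picks the local filename); "-c" resumes partial downloads and "-d" sets the target directory. A minimal Python sketch of the same mechanism (placeholder URLs and paths, not part of the commit):

    import subprocess

    # Placeholder download list; each entry is (url, local filename).
    downloads = [
        ("https://example.com/model-part-1.bin", "part1.bin"),
        ("https://example.com/model-part-2.bin", "part2.bin"),
    ]

    # aria2 input-file format: a URI on one line, then indented
    # option=value lines that apply to that URI.
    input_list = "".join(f"{url}\n  out={name}\n" for url, name in downloads)

    # "-i -" makes aria2c read the list from stdin, like the shell
    # pipeline above; "-c" resumes interrupted downloads.
    subprocess.run(
        ["aria2c", "-c", "-i", "-", "-d", "/tmp/downloads",
         "--user-agent=KoboldAI", "--file-allocation=none"],
        input=input_list.encode(),
        check=True,
    )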
utils.py (91 lines changed)
@@ -1,5 +1,11 @@
 from threading import Timer
 import re
+import shutil
+import json
+import subprocess
+import tempfile
+import requests
+import os

 vars = None

@@ -125,3 +131,88 @@ def decodenewlines(txt):
     if(vars.newlinemode == "s"):
         return txt.replace("</s>", '\n')
     return txt
+
+#==================================================================#
+# Downloads sharded huggingface checkpoints using aria2c if possible
+#==================================================================#
+def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
+    import transformers
+    import transformers.modeling_utils
+    from huggingface_hub import HfFolder
+    if shutil.which("aria2c") is None:  # Don't do anything if aria2 is not installed
+        return
+    if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index"):
+        return
+    if proxies:
+        print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
+        return
+    if use_auth_token:
+        if isinstance(use_auth_token, str):
+            token = use_auth_token
+        else:
+            token = HfFolder.get_token()
+            if token is None:
+                raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
+    _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
+    sharded = False
+    while True:  # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
+        try:
+            filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
+        except AttributeError:
+            return
+        url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror)
+        try:
+            transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent)
+        except transformers.file_utils.RepositoryNotFoundError:
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+                "login` and pass `use_auth_token=True`."
+            )
+        except transformers.file_utils.RevisionNotFoundError:
+            raise EnvironmentError(
+                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+                "this model name. Check the model page at "
+                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+            )
+        except transformers.file_utils.EntryNotFoundError:
+            if sharded:
+                return
+            else:
+                sharded = True
+        else:
+            break
+    if not sharded:  # If the model has a pytorch_model.bin file, that's the only large file to download so it's probably more efficient to just let transformers download it
+        return
+    # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
+    map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent)
+    with open(map_filename) as f:
+        map_data = json.load(f)
+    filenames = set(map_data["weight_map"].values())
+    urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
+    def is_cached(url):
+        try:
+            transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True)
+        except FileNotFoundError:
+            return False
+        return True
+    if not force_download:
+        if all(is_cached(u) for u in urls):
+            return
+        elif local_files_only:
+            raise FileNotFoundError("Cannot find the requested files in the cached path and outgoing traffic has been disabled. To enable model look-ups and downloads online, set 'local_files_only' to False.")
+    headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
+    if use_auth_token:
+        headers["authorization"] = f"Bearer {use_auth_token}"
+    etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
+    filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
+    aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode()
+    with tempfile.NamedTemporaryFile("w+b", delete=True) as f:
+        f.write(aria2_config)
+        p = subprocess.Popen(["aria2c", "-d", _cache_dir, "-i", f.name] + (["-c"] if not force_download else []) + (["-U", str(user_agent)] if user_agent is not None else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        for line in p.stdout:
            print(line.decode(), end="", flush=True)
    for u, t, n in zip(urls, etags, filenames):
        with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
            json.dump({"url": u, "etag": t}, f)
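The closing loop writes the metadata that makes the transformers cache accept aria2's output. In the transformers 4.x cache this commit targets, url_to_filename derives an entry's filename from hashes of the URL and its ETag, and get_from_cache recognizes an entry by a matching <name>.json sidecar recording both values; downloading each shard under that derived name and then writing the sidecar makes the later from_pretrained call treat every shard as already fetched. A rough sketch of that layout, assuming the sha256-based naming scheme of transformers.file_utils.url_to_filename (my reconstruction, not code from the commit):

    import hashlib
    import json
    import os

    def cache_entry_name(url: str, etag: str) -> str:
        # Assumed naming scheme: sha256(url) + "." + sha256(etag), as in
        # transformers 4.x file_utils.url_to_filename.
        url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()
        etag_hash = hashlib.sha256(etag.encode("utf-8")).hexdigest()
        return f"{url_hash}.{etag_hash}"

    def write_sidecar(cache_dir: str, url: str, etag: str) -> None:
        # The .json sidecar maps a cached file back to its source URL and
        # ETag; get_from_cache uses it to detect an existing download.
        name = cache_entry_name(url, etag)
        with open(os.path.join(cache_dir, name + ".json"), "w") as f:
            json.dump({"url": url, "etag": etag}, f)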