Compare commits: aac999c073...bf3a838cf1

5 commits

Author | SHA1 | Date
---|---|---
henk717 | bf3a838cf1 |
vfbd | d55d8232d0 |
vfbd | 463bf86bcc |
vfbd | 7bf6c9a23f |
vfbd | 551565c5ac |
@@ -1713,11 +1713,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
 import transformers.configuration_utils
 import transformers.modeling_utils
 import transformers.file_utils
+import huggingface_hub
+legacy = packaging.version.parse(transformers_version) < packaging.version.parse("4.22.0.dev0")
 # Save the config.json
-shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
+shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
 if(utils.num_shards is None):
     # Save the pytorch_model.bin of an unsharded model
-    shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
+    shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
 else:
     with open(utils.from_pretrained_index_filename) as f:
         map_data = json.load(f)
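The change above swaps transformers' internal `get_from_cache`/`hf_bucket_url` helpers (removed upstream) for `huggingface_hub.hf_hub_download`, gated by a `legacy` flag so pre-4.22 installs keep the old flat cache. A minimal sketch of that gate, assuming `transformers_version` holds the installed version string as in the surrounding code and that `legacy_cache_layout` is available (it is a huggingface_hub parameter of that era); the model id is a placeholder:

```python
import packaging.version
import huggingface_hub
from transformers import __version__ as transformers_version

# transformers < 4.22 still expects the flat "url-hash" cache layout,
# so hf_hub_download is told to keep writing that layout there.
legacy = packaging.version.parse(transformers_version) < packaging.version.parse("4.22.0.dev0")

config_path = huggingface_hub.hf_hub_download(
    "EleutherAI/gpt-neo-125M",   # placeholder model id
    "config.json",
    cache_dir="cache",
    legacy_cache_layout=legacy,  # True selects the old flat cache layout
)
```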
@@ -1726,7 +1728,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     shutil.move(utils.from_pretrained_index_filename, os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_INDEX_NAME))
     # Then save the pytorch_model-#####-of-#####.bin files
     for filename in filenames:
-        shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
+        shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, filename, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
 shutil.rmtree("cache/")
 
 if(vars.hascuda):
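On the new cache layout, `hf_hub_download` returns a symlink under `snapshots/`, so the diff resolves it with `os.path.realpath` before `shutil.move`; moving the unresolved pointer would relocate a dangling symlink rather than the weights. A sketch with placeholder repo and paths:

```python
import os
import shutil
import huggingface_hub

# The returned path is a snapshots/ symlink into blobs/ on the new layout;
# realpath follows it so shutil.move relocates the actual file.
pointer = huggingface_hub.hf_hub_download(
    "some/model", "pytorch_model.bin",   # placeholder repo and filename
    cache_dir="cache", local_files_only=True,
)
shutil.move(os.path.realpath(pointer), "models/some_model/pytorch_model.bin")
```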
@@ -20,5 +20,5 @@ dependencies:
 - flask-cloudflared
 - flask-ngrok
 - lupa==1.10
-- transformers==4.21.3
+- transformers>=4.20.1
 - accelerate

@@ -20,5 +20,5 @@ dependencies:
 - flask-cloudflared
 - flask-ngrok
 - lupa==1.10
-- transformers==4.21.3
+- transformers>=4.20.1
 - accelerate
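Both environment files relax the exact `transformers==4.21.3` pin to a range, so installs may now resolve to 4.22+, which is what the new download path supports. A quick check of what the range admits, using the `packaging` library (the version strings below are illustrative):

```python
from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=4.20.1")
print(spec.contains("4.21.3"))  # True: the previously pinned release still satisfies it
print(spec.contains("4.22.0"))  # True: releases with the new hub cache layout are now allowed
```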
@@ -1,4 +1,4 @@
-transformers==4.21.3
+transformers>=4.20.1
 Flask
 Flask-SocketIO
 requests

@@ -11,4 +11,4 @@ markdown
 bleach==4.1.0
 sentencepiece
 protobuf
 accelerate
@@ -2,11 +2,10 @@ torch >= 1.9, <= 1.11
 numpy
 tqdm
 requests
-optax >= 0.0.5, <= 0.0.9
 dm-haiku == 0.0.5
 jax == 0.2.21
 jaxlib >= 0.1.69, <= 0.3.7
-transformers == 4.21.3
+transformers >= 4.20.1
 progressbar2
 git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
 flask

@@ -17,4 +16,3 @@ eventlet
 lupa==1.10
 markdown
 bleach==4.1.0
-chex==0.1.4
@@ -30,7 +30,7 @@ SOFTWARE.
 import utils
 
 import multiprocessing
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
 import progressbar
 import time
 import os

@@ -45,7 +45,6 @@ from jax.config import config
 from jax.experimental import maps
 import jax.numpy as jnp
 import numpy as np
-import optax
 import haiku as hk
 from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
 from tokenizers import Tokenizer

@@ -120,6 +119,14 @@ def __batch_xmap(shard_dim=1):
     return inner
 
 
+class _EmptyState(NamedTuple):
+    pass
+
+class _DummyOptimizer:
+    def init(*args, **kwargs):
+        return _EmptyState()
+
+
 def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
     '''
     This gets called by generate_loop_fn to apply repetition penalty

@@ -1148,7 +1155,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
 
     cores_per_replica = params["cores_per_replica"]
     seq = params["seq"]
-    params["optimizer"] = optax.scale(0)
+    params["optimizer"] = _DummyOptimizer()
     mesh_shape = (1, cores_per_replica)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
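With `optax` dropped from the TPU requirements, `params["optimizer"]` needs a stand-in that mesh-transformer-jax can still call; `_DummyOptimizer` replaces `optax.scale(0)`, which already left the weights untouched at inference time. A self-contained sketch of the stub, on the apparent assumption behind the change that only `init` is ever invoked:

```python
from typing import NamedTuple

class _EmptyState(NamedTuple):
    # An empty pytree: nothing for JAX to shard or update.
    pass

class _DummyOptimizer:
    # Inference never takes an optimizer step, so init can return
    # an empty state and no update method is needed.
    def init(*args, **kwargs):
        return _EmptyState()

opt = _DummyOptimizer()
state = opt.init(None)  # -> _EmptyState()
```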
utils.py (370 changed lines)
@@ -4,12 +4,17 @@ import shutil
 import json
 import subprocess
 import tempfile
+from urllib.error import HTTPError
 import requests
 import requests.adapters
 import time
 from tqdm.auto import tqdm
 import os
 import itertools
+import hashlib
+import huggingface_hub
+import packaging.version
+from pathlib import Path
 from typing import Optional
 
 vars = None
@@ -159,81 +164,9 @@ def num_layers(config):
 #==================================================================#
 # Downloads huggingface checkpoints using aria2c if possible
 #==================================================================#
-def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
+def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None):
     import transformers
-    import transformers.modeling_utils
-    from huggingface_hub import HfFolder
-    if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
-        return
-    if local_files_only: # If local_files_only is true, we obviously don't need to download anything
-        return
-    if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
-        return
-    if proxies:
-        print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
-        return
-    if use_auth_token:
-        if isinstance(use_auth_token, str):
-            token = use_auth_token
-        else:
-            token = HfFolder.get_token()
-            if token is None:
-                raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
-    _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
-    sharded = False
-    headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
-    if use_auth_token:
-        headers["authorization"] = f"Bearer {use_auth_token}"
-    def is_cached(url):
-        try:
-            transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True)
-        except (FileNotFoundError, transformers.file_utils.EntryNotFoundError):
-            return False
-        return True
-    while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
-        try:
-            filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
-        except AttributeError:
-            return
-        url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror)
-        if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
-            break
-        if sharded:
-            return
-        else:
-            sharded = True
-    if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
-        filenames = [transformers.modeling_utils.WEIGHTS_NAME]
-    else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
-        map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
-        with open(map_filename) as f:
-            map_data = json.load(f)
-        filenames = set(map_data["weight_map"].values())
-    urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
-    if not force_download:
-        urls = [u for u in urls if not is_cached(u)]
-        if not urls:
-            return
-    etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
-    headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
-    filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
-    for n in filenames:
-        path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
-        if os.path.exists(path):
-            os.remove(path)
-        path = os.path.join(_cache_dir, "kai-tempfile." + n)
-        if os.path.exists(path):
-            os.remove(path)
-        if force_download:
-            path = os.path.join(_cache_dir, n + ".json")
-            if os.path.exists(path):
-                os.remove(path)
-            path = os.path.join(_cache_dir, n)
-            if os.path.exists(path):
-                os.remove(path)
-    total_length = sum(int(h["Content-Length"]) for h in headers)
     lengths = {}
-    aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
     s = requests.Session()
     s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
     bar = None
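The download mechanics now live in `_download_with_aria2`, which takes a prebuilt `aria2_config` (an aria2c input file as bytes) plus a `total_length` for the progress bar. The config format, mirroring the expression kept in the diff, is a URL line followed by an indented `out=` option; `urls` and `filenames` below are placeholders for the lists the hook computes:

```python
# Build an aria2c input file: one URL line, then an indented per-download option.
urls = ["https://huggingface.co/some/model/resolve/main/pytorch_model.bin"]  # placeholder
filenames = ["pytorch_model.bin"]                                            # placeholder

aria2_config = "\n".join(
    f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)
).encode()

with open("downloads.aria2", "wb") as f:
    f.write(aria2_config)
# `aria2c -i downloads.aria2 -d <dir>` would then fetch every listed file.
```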
@@ -243,7 +176,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
     with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
         f.write(aria2_config)
         f.flush()
-        p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+        p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", directory, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {use_auth_token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
         while p.poll() is None:
             r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
             if not r:
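The hook launches aria2c with `--enable-rpc` and polls it over JSON-RPC to drive the progress bar. A sketch of a single poll, with a hypothetical port and secret (the real code uses `vars.aria2_port` and a generated secret):

```python
import requests

port, secret = 6800, "example-secret"  # hypothetical values

active = requests.post(
    f"http://localhost:{port}/jsonrpc",
    json={"jsonrpc": "2.0", "id": "kai",
          "method": "aria2.tellActive", "params": [f"token:{secret}"]},
).json()["result"]
# Each entry carries completedLength/totalLength strings, which the
# hook sums across downloads to update its progress bar.
```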
@@ -279,6 +212,291 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
     code = p.wait()
     if not done and code:
         raise OSError(f"aria2 exited with exit code {code}")
+
+def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
+    import transformers
+    import transformers.modeling_utils
+    from huggingface_hub import HfFolder
+    if use_auth_token:
+        if isinstance(use_auth_token, str):
+            token = use_auth_token
+        else:
+            token = HfFolder.get_token()
+            if token is None:
+                raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
+    _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
+    _revision = revision if revision is not None else huggingface_hub.constants.DEFAULT_REVISION
+    sharded = False
+    headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
+    if use_auth_token:
+        headers["authorization"] = f"Bearer {use_auth_token}"
+
+    storage_folder = os.path.join(_cache_dir, huggingface_hub.file_download.repo_folder_name(repo_id=pretrained_model_name_or_path, repo_type="model"))
+    os.makedirs(storage_folder, exist_ok=True)
+
+    def is_cached(filename):
+        try:
+            huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, local_files_only=True)
+        except ValueError:
+            return False
+        return True
+    while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
+        try:
+            filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
+        except AttributeError:
+            return
+        url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
+        if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
+            break
+        if sharded:
+            return
+        else:
+            sharded = True
+    if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
+        filenames = [transformers.modeling_utils.WEIGHTS_NAME]
+    else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
+        map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
+        with open(map_filename) as f:
+            map_data = json.load(f)
+        filenames = set(map_data["weight_map"].values())
+    urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
+    if not force_download:
+        urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]
+        if not urls:
+            return
+
+    blob_paths = []
+
+    # This section is a modified version of hf_hub_download from huggingface_hub
+    # See https://github.com/huggingface/huggingface_hub/blob/main/LICENSE for license
+    for u, n in zip(urls, filenames):
+        relative_filename = os.path.join(*n.split("/"))
+        if not local_files_only:
+            try:
+                r = huggingface_hub.file_download._request_wrapper(
+                    method="HEAD",
+                    url=u,
+                    headers=headers,
+                    allow_redirects=False,
+                    follow_relative_redirects=True,
+                    proxies=proxies,
+                    timeout=10,
+                )
+                try:
+                    r.raise_for_status()
+                except HTTPError as e:
+                    error_code = r.headers.get("X-Error-Code")
+                    if error_code != "EntryNotFound":
+                        raise RuntimeError(f"HEAD {u} failed with error code {r.status_code}")
+                    commit_hash = r.headers.get(huggingface_hub.file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT)
+                    if commit_hash is not None:
+                        no_exist_file_path = (
+                            Path(storage_folder)
+                            / ".no_exist"
+                            / commit_hash
+                            / relative_filename
+                        )
+                        no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
+                        no_exist_file_path.touch()
+                        huggingface_hub.file_download._cache_commit_hash_for_specific_revision(
+                            storage_folder, _revision, commit_hash
+                        )
+                    raise
+                commit_hash = r.headers[huggingface_hub.file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT]
+                if commit_hash is None:
+                    raise OSError(
+                        "Distant resource does not seem to be on huggingface.co (missing"
+                        " commit header)."
+                    )
+                etag = r.headers.get(huggingface_hub.file_download.HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get(
+                    "ETag"
+                )
+                # We favor a custom header indicating the etag of the linked resource, and
+                # we fallback to the regular etag header.
+                # If we don't have any of those, raise an error.
+                if etag is None:
+                    raise OSError(
+                        "Distant resource does not have an ETag, we won't be able to"
+                        " reliably ensure reproducibility."
+                    )
+                etag = huggingface_hub.file_download._normalize_etag(etag)
+                # In case of a redirect, save an extra redirect on the request.get call,
+                # and ensure we download the exact atomic version even if it changed
+                # between the HEAD and the GET (unlikely, but hey).
+                # Useful for lfs blobs that are stored on a CDN.
+                if 300 <= r.status_code <= 399:
+                    url_to_download = r.headers["Location"]
+                    if (
+                        "lfs.huggingface.co" in url_to_download
+                        or "lfs-staging.huggingface.co" in url_to_download
+                    ):
+                        # Remove authorization header when downloading a LFS blob
+                        headers.pop("authorization", None)
+            except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
+                # Actually raise for those subclasses of ConnectionError
+                raise
+            except (
+                requests.exceptions.ConnectionError,
+                requests.exceptions.Timeout,
+                huggingface_hub.file_download.OfflineModeIsEnabled,
+            ):
+                # Otherwise, our Internet connection is down.
+                # etag is None
+                pass
+        if etag is None:
+            # In those cases, we cannot force download.
+            if force_download:
+                raise ValueError(
+                    "We have no connection or you passed local_files_only, so"
+                    " force_download is not an accepted option."
+                )
+            if huggingface_hub.file_download.REGEX_COMMIT_HASH.match(_revision):
+                commit_hash = _revision
+            else:
+                ref_path = os.path.join(storage_folder, "refs", _revision)
+                with open(ref_path) as f:
+                    commit_hash = f.read()
+            pointer_path = os.path.join(
+                storage_folder, "snapshots", commit_hash, relative_filename
+            )
+            if os.path.exists(pointer_path):
+                return pointer_path
+            # If we couldn't find an appropriate file on disk,
+            # raise an error.
+            # If files cannot be found and local_files_only=True,
+            # the models might've been found if local_files_only=False
+            # Notify the user about that
+            if local_files_only:
+                raise huggingface_hub.file_download.LocalEntryNotFoundError(
+                    "Cannot find the requested files in the disk cache and"
+                    " outgoing traffic has been disabled. To enable hf.co look-ups"
+                    " and downloads online, set 'local_files_only' to False."
+                )
+            else:
+                raise huggingface_hub.file_download.LocalEntryNotFoundError(
+                    "Connection error, and we cannot find the requested files in"
+                    " the disk cache. Please try again or make sure your Internet"
+                    " connection is on."
+                )
+        # From now on, etag and commit_hash are not None.
+        blob_path = os.path.join(storage_folder, "blobs", etag)
+        pointer_path = os.path.join(
+            storage_folder, "snapshots", commit_hash, relative_filename
+        )
+        os.makedirs(os.path.dirname(blob_path), exist_ok=True)
+        os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
+        # if passed revision is not identical to commit_hash
+        # then revision has to be a branch name or tag name.
+        # In that case store a ref.
+        huggingface_hub.file_download._cache_commit_hash_for_specific_revision(storage_folder, _revision, commit_hash)
+        if os.path.exists(pointer_path) and not force_download:
+            return pointer_path
+        if os.path.exists(blob_path) and not force_download:
+            # we have the blob already, but not the pointer
+            huggingface_hub.file_download.logger.info("creating pointer to %s from %s", blob_path, pointer_path)
+            huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
+            return pointer_path
+        # Some Windows versions do not allow for paths longer than 255 characters.
+        # In this case, we must specify it is an extended path by using the "\\?\" prefix.
+        if os.name == "nt" and len(os.path.abspath(blob_path)) > 255:
+            blob_path = "\\\\?\\" + os.path.abspath(blob_path)
+        blob_paths.append(blob_path)
+
+    filenames = blob_paths
+    headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
+
+    for n in filenames:
+        prefix, suffix = n.rsplit("/", 1)
+        path = os.path.join(prefix, "kai-tempfile." + suffix + ".aria2")
+        if os.path.exists(path):
+            os.remove(path)
+        path = os.path.join(prefix, "kai-tempfile." + suffix)
+        if os.path.exists(path):
+            os.remove(path)
+    total_length = sum(int(h["Content-Length"]) for h in headers)
+    aria2_config = "\n".join(f"{u}\n out={os.path.join(prefix, 'kai-tempfile.' + suffix)}" for u, n in zip(urls, filenames) for prefix, suffix in [n.rsplit("/", 1)]).encode()
+    _download_with_aria2(aria2_config, total_length, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
+    for u, n in zip(urls, filenames):
+        prefix, suffix = n.rsplit("/", 1)
+        os.rename(os.path.join(prefix, "kai-tempfile." + suffix), os.path.join(prefix, suffix))
+
+def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
+    import transformers
+    import transformers.modeling_utils
+    from huggingface_hub import HfFolder
+    if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
+        return
+    if local_files_only: # If local_files_only is true, we obviously don't need to download anything
+        return
+    if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
+        return
+    if proxies:
+        print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
+        return
+    if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.22.0.dev0"):
+        return _transformers22_aria2_hook(pretrained_model_name_or_path, force_download=force_download, cache_dir=cache_dir, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, **kwargs)
+    if use_auth_token:
+        if isinstance(use_auth_token, str):
+            token = use_auth_token
+        else:
+            token = HfFolder.get_token()
+            if token is None:
+                raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
+    _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
+    sharded = False
+    headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
+    if use_auth_token:
+        headers["authorization"] = f"Bearer {use_auth_token}"
+    def is_cached(url):
+        try:
+            huggingface_hub.cached_download(url, cache_dir=cache_dir, local_files_only=True)
+        except ValueError:
+            return False
+        return True
+    while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
+        try:
+            filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
+        except AttributeError:
+            return
+        url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
+        if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
+            break
+        if sharded:
+            return
+        else:
+            sharded = True
+    if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
+        filenames = [transformers.modeling_utils.WEIGHTS_NAME]
+    else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
+        map_filename = huggingface_hub.cached_download(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
+        with open(map_filename) as f:
+            map_data = json.load(f)
+        filenames = set(map_data["weight_map"].values())
+    urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
+    if not force_download:
+        urls = [u for u in urls if not is_cached(u)]
+        if not urls:
+            return
+    etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
+    headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
+    filenames = [hashlib.sha256(u.encode("utf-8")).hexdigest() + "." + hashlib.sha256(t.encode("utf-8")).hexdigest() for u, t in zip(urls, etags)]
+    for n in filenames:
+        path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
+        if os.path.exists(path):
+            os.remove(path)
+        path = os.path.join(_cache_dir, "kai-tempfile." + n)
+        if os.path.exists(path):
+            os.remove(path)
+        if force_download:
+            path = os.path.join(_cache_dir, n + ".json")
+            if os.path.exists(path):
+                os.remove(path)
+            path = os.path.join(_cache_dir, n)
+            if os.path.exists(path):
+                os.remove(path)
+    total_length = sum(int(h["Content-Length"]) for h in headers)
+    aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
+    _download_with_aria2(aria2_config, total_length, directory=_cache_dir, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
     for u, t, n in zip(urls, etags, filenames):
         os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
         with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
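The added `_transformers22_aria2_hook` re-implements `hf_hub_download`'s resolution logic so aria2 can write straight into the new hub cache, which stores a repo as `models--org--name/` with `refs/` (branch to commit hash), `snapshots/<commit>/` (symlink tree), and `blobs/<etag>` (file contents). A minimal sketch of that resolution, assuming the standard huggingface_hub >= 0.9 layout; `resolve_cached_file` is an illustrative helper, not a function from the diff:

```python
import os

def resolve_cached_file(cache_dir: str, repo_id: str, revision: str, filename: str) -> str:
    # Repo folders replace "/" with "--" and are prefixed with the repo type.
    storage_folder = os.path.join(cache_dir, "models--" + repo_id.replace("/", "--"))
    # A branch or tag name maps to a commit hash through refs/.
    ref_path = os.path.join(storage_folder, "refs", revision)
    if os.path.exists(ref_path):
        with open(ref_path) as f:
            revision = f.read()
    # snapshots/<commit>/<filename> is a symlink into blobs/<etag>.
    pointer = os.path.join(storage_folder, "snapshots", revision, filename)
    return os.path.realpath(pointer)

# e.g. resolve_cached_file("cache", "EleutherAI/gpt-neo-125M", "main", "config.json")
```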
@@ -298,8 +516,8 @@ def get_num_shards(filename):
 # pytorch_model.bin.index.json, returns a list of weight names in the
 # sharded model. Requires lazy loader to be enabled to work properl
 #==================================================================#
-def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
+def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
     import transformers.modeling_utils
     import torch
-    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
+    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
     return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
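The only change here is dropping the `mirror` argument, presumably because newer transformers releases removed it from `get_checkpoint_shard_files`; the body still enumerates weight names shard by shard. A sketch of that enumeration with placeholder shard paths:

```python
import itertools
import torch

# Placeholder paths; the real list comes from
# transformers.modeling_utils.get_checkpoint_shard_files(...).
shard_paths = ["pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin"]

tensor_names = list(itertools.chain(
    *(torch.load(p, map_location="cpu").keys() for p in shard_paths)
))
```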