Merge pull request #209 from VE-FORBRYDERNE/dependency-fix

Fix compatibility issues with transformers and optax/chex
This commit is contained in:
henk717 2022-09-15 23:48:13 +02:00 committed by GitHub
commit 77763da6e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 317 additions and 92 deletions

View File

@ -1695,11 +1695,11 @@ def patch_transformers_download():
pass
def http_get(
url: str,
temp_file: transformers.utils.hub.BinaryIO,
temp_file,
proxies=None,
resume_size=0,
headers: transformers.utils.hub.Optional[transformers.utils.hub.Dict[str, str]] = None,
file_name: transformers.utils.hub.Optional[str] = None,
headers=None,
file_name=None,
):
"""
Download remote file. Do not gobble up errors.
@ -2527,11 +2527,13 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
import transformers.configuration_utils
import transformers.modeling_utils
import transformers.file_utils
import huggingface_hub
legacy = packaging.version.parse(transformers_version) < packaging.version.parse("4.22.0.dev0")
# Save the config.json
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
if(utils.num_shards is None):
# Save the pytorch_model.bin of an unsharded model
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
else:
with open(utils.from_pretrained_index_filename) as f:
map_data = json.load(f)
@ -2540,7 +2542,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
shutil.move(utils.from_pretrained_index_filename, os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_INDEX_NAME))
# Then save the pytorch_model-#####-of-#####.bin files
for filename in filenames:
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, filename, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
shutil.rmtree("cache/")
if(vars.badwordsids is vars.badwordsids_default and vars.model_type not in ("gpt2", "gpt_neo", "gptj")):

View File

@ -23,6 +23,6 @@ dependencies:
- flask-cloudflared
- flask-ngrok
- lupa==1.10
- transformers==4.21.3
- transformers>=4.20.1
- accelerate
- loguru

View File

@ -23,5 +23,5 @@ dependencies:
- flask-cloudflared
- flask-ngrok
- lupa==1.10
- transformers==4.21.3
- transformers>=4.20.1
- accelerate

View File

@ -1,4 +1,4 @@
transformers==4.21.3
transformers>=4.20.1
Flask
Flask-SocketIO
requests

View File

@ -2,11 +2,10 @@ torch >= 1.9, <= 1.11
numpy
tqdm
requests
optax >= 0.0.5, <= 0.0.9
dm-haiku == 0.0.5
jax == 0.2.21
jaxlib >= 0.1.69, <= 0.3.7
transformers == 4.21.3
transformers >= 4.20.1
progressbar2
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
flask
@ -17,7 +16,6 @@ eventlet
lupa==1.10
markdown
bleach==4.1.0
chex==0.1.4
flask-session
marshmallow>=3.13
apispec-webframeworks

View File

@ -30,7 +30,7 @@ SOFTWARE.
import utils
import multiprocessing
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
import progressbar
import time
import os
@ -45,7 +45,6 @@ from jax.config import config
from jax.experimental import maps
import jax.numpy as jnp
import numpy as np
import optax
import haiku as hk
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
from tokenizers import Tokenizer
@ -136,6 +135,14 @@ def __batch_xmap(shard_dim=1):
return inner
class _EmptyState(NamedTuple):
pass
class _DummyOptimizer:
def init(*args, **kwargs):
return _EmptyState()
def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
'''
This gets called by generate_loop_fn to apply repetition penalty
@ -1167,7 +1174,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]
params["optimizer"] = optax.scale(0)
params["optimizer"] = _DummyOptimizer()
mesh_shape = (1, cores_per_replica)
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())

372
utils.py
View File

@ -4,6 +4,7 @@ import shutil
import json
import subprocess
import tempfile
from urllib.error import HTTPError
import requests
import requests.adapters
import time
@ -13,6 +14,10 @@ import packaging.version
from tqdm.auto import tqdm
import os
import itertools
import hashlib
import huggingface_hub
import packaging.version
from pathlib import Path
from typing import List, Optional
HAS_ACCELERATE = packaging.version.parse(transformers_version) >= packaging.version.parse("4.20.0.dev0")
@ -182,81 +187,9 @@ class Send_to_socketio(object):
except:
pass
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None):
import transformers
import transformers.modeling_utils
from huggingface_hub import HfFolder
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
return
if local_files_only: # If local_files_only is true, we obviously don't need to download anything
return
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
return
if proxies:
print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
return
if use_auth_token:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token:
headers["authorization"] = f"Bearer {use_auth_token}"
def is_cached(url):
try:
transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True)
except (FileNotFoundError, transformers.file_utils.EntryNotFoundError):
return False
return True
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
try:
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror)
if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
return
else:
sharded = True
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
if not force_download:
urls = [u for u in urls if not is_cached(u)]
if not urls:
return
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
for n in filenames:
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, "kai-tempfile." + n)
if os.path.exists(path):
os.remove(path)
if force_download:
path = os.path.join(_cache_dir, n + ".json")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, n)
if os.path.exists(path):
os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
lengths = {}
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
s = requests.Session()
s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
bar = None
@ -266,7 +199,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
f.write(aria2_config)
f.flush()
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", directory, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {use_auth_token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
while p.poll() is None:
r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
if not r:
@ -278,7 +211,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
done = True
break
if bar is None:
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
visited = set()
for x in r:
filename = x["files"][0]["path"]
@ -302,6 +235,291 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
code = p.wait()
if not done and code:
raise OSError(f"aria2 exited with exit code {code}")
def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers
import transformers.modeling_utils
from huggingface_hub import HfFolder
if use_auth_token:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
_revision = revision if revision is not None else huggingface_hub.constants.DEFAULT_REVISION
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token:
headers["authorization"] = f"Bearer {use_auth_token}"
storage_folder = os.path.join(_cache_dir, huggingface_hub.file_download.repo_folder_name(repo_id=pretrained_model_name_or_path, repo_type="model"))
os.makedirs(storage_folder, exist_ok=True)
def is_cached(filename):
try:
huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, local_files_only=True)
except ValueError:
return False
return True
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
try:
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
return
else:
sharded = True
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
if not force_download:
urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]
if not urls:
return
blob_paths = []
# This section is a modified version of hf_hub_download from huggingface_hub
# See https://github.com/huggingface/huggingface_hub/blob/main/LICENSE for license
for u, n in zip(urls, filenames):
relative_filename = os.path.join(*n.split("/"))
if not local_files_only:
try:
r = huggingface_hub.file_download._request_wrapper(
method="HEAD",
url=u,
headers=headers,
allow_redirects=False,
follow_relative_redirects=True,
proxies=proxies,
timeout=10,
)
try:
r.raise_for_status()
except HTTPError as e:
error_code = r.headers.get("X-Error-Code")
if error_code != "EntryNotFound":
raise RuntimeError(f"HEAD {u} failed with error code {r.status_code}")
commit_hash = r.headers.get(huggingface_hub.file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT)
if commit_hash is not None:
no_exist_file_path = (
Path(storage_folder)
/ ".no_exist"
/ commit_hash
/ relative_filename
)
no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
no_exist_file_path.touch()
huggingface_hub.file_download._cache_commit_hash_for_specific_revision(
storage_folder, _revision, commit_hash
)
raise
commit_hash = r.headers[huggingface_hub.file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT]
if commit_hash is None:
raise OSError(
"Distant resource does not seem to be on huggingface.co (missing"
" commit header)."
)
etag = r.headers.get(huggingface_hub.file_download.HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get(
"ETag"
)
# We favor a custom header indicating the etag of the linked resource, and
# we fallback to the regular etag header.
# If we don't have any of those, raise an error.
if etag is None:
raise OSError(
"Distant resource does not have an ETag, we won't be able to"
" reliably ensure reproducibility."
)
etag = huggingface_hub.file_download._normalize_etag(etag)
# In case of a redirect, save an extra redirect on the request.get call,
# and ensure we download the exact atomic version even if it changed
# between the HEAD and the GET (unlikely, but hey).
# Useful for lfs blobs that are stored on a CDN.
if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"]
if (
"lfs.huggingface.co" in url_to_download
or "lfs-staging.huggingface.co" in url_to_download
):
# Remove authorization header when downloading a LFS blob
headers.pop("authorization", None)
except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
# Actually raise for those subclasses of ConnectionError
raise
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
huggingface_hub.file_download.OfflineModeIsEnabled,
):
# Otherwise, our Internet connection is down.
# etag is None
pass
if etag is None:
# In those cases, we cannot force download.
if force_download:
raise ValueError(
"We have no connection or you passed local_files_only, so"
" force_download is not an accepted option."
)
if huggingface_hub.file_download.REGEX_COMMIT_HASH.match(_revision):
commit_hash = _revision
else:
ref_path = os.path.join(storage_folder, "refs", _revision)
with open(ref_path) as f:
commit_hash = f.read()
pointer_path = os.path.join(
storage_folder, "snapshots", commit_hash, relative_filename
)
if os.path.exists(pointer_path):
return pointer_path
# If we couldn't find an appropriate file on disk,
# raise an error.
# If files cannot be found and local_files_only=True,
# the models might've been found if local_files_only=False
# Notify the user about that
if local_files_only:
raise huggingface_hub.file_download.LocalEntryNotFoundError(
"Cannot find the requested files in the disk cache and"
" outgoing traffic has been disabled. To enable hf.co look-ups"
" and downloads online, set 'local_files_only' to False."
)
else:
raise huggingface_hub.file_download.LocalEntryNotFoundError(
"Connection error, and we cannot find the requested files in"
" the disk cache. Please try again or make sure your Internet"
" connection is on."
)
# From now on, etag and commit_hash are not None.
blob_path = os.path.join(storage_folder, "blobs", etag)
pointer_path = os.path.join(
storage_folder, "snapshots", commit_hash, relative_filename
)
os.makedirs(os.path.dirname(blob_path), exist_ok=True)
os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
# if passed revision is not identical to commit_hash
# then revision has to be a branch name or tag name.
# In that case store a ref.
huggingface_hub.file_download._cache_commit_hash_for_specific_revision(storage_folder, _revision, commit_hash)
if os.path.exists(pointer_path) and not force_download:
return pointer_path
if os.path.exists(blob_path) and not force_download:
# we have the blob already, but not the pointer
huggingface_hub.file_download.logger.info("creating pointer to %s from %s", blob_path, pointer_path)
huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
return pointer_path
# Some Windows versions do not allow for paths longer than 255 characters.
# In this case, we must specify it is an extended path by using the "\\?\" prefix.
if os.name == "nt" and len(os.path.abspath(blob_path)) > 255:
blob_path = "\\\\?\\" + os.path.abspath(blob_path)
blob_paths.append(blob_path)
filenames = blob_paths
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
for n in filenames:
prefix, suffix = n.rsplit("/", 1)
path = os.path.join(prefix, "kai-tempfile." + suffix + ".aria2")
if os.path.exists(path):
os.remove(path)
path = os.path.join(prefix, "kai-tempfile." + suffix)
if os.path.exists(path):
os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
aria2_config = "\n".join(f"{u}\n out={os.path.join(prefix, 'kai-tempfile.' + suffix)}" for u, n in zip(urls, filenames) for prefix, suffix in [n.rsplit("/", 1)]).encode()
_download_with_aria2(aria2_config, total_length, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
for u, n in zip(urls, filenames):
prefix, suffix = n.rsplit("/", 1)
os.rename(os.path.join(prefix, "kai-tempfile." + suffix), os.path.join(prefix, suffix))
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers
import transformers.modeling_utils
from huggingface_hub import HfFolder
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
return
if local_files_only: # If local_files_only is true, we obviously don't need to download anything
return
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
return
if proxies:
print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
return
if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.22.0.dev0"):
return _transformers22_aria2_hook(pretrained_model_name_or_path, force_download=force_download, cache_dir=cache_dir, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, **kwargs)
if use_auth_token:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token:
headers["authorization"] = f"Bearer {use_auth_token}"
def is_cached(url):
try:
huggingface_hub.cached_download(url, cache_dir=cache_dir, local_files_only=True)
except ValueError:
return False
return True
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
try:
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
return
else:
sharded = True
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = huggingface_hub.cached_download(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
if not force_download:
urls = [u for u in urls if not is_cached(u)]
if not urls:
return
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
filenames = [hashlib.sha256(u.encode("utf-8")).hexdigest() + "." + hashlib.sha256(t.encode("utf-8")).hexdigest() for u, t in zip(urls, etags)]
for n in filenames:
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, "kai-tempfile." + n)
if os.path.exists(path):
os.remove(path)
if force_download:
path = os.path.join(_cache_dir, n + ".json")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, n)
if os.path.exists(path):
os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
_download_with_aria2(aria2_config, total_length, directory=_cache_dir, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
for u, t, n in zip(urls, etags, filenames):
os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
@ -321,10 +539,10 @@ def get_num_shards(filename):
# pytorch_model.bin.index.json, returns a list of weight names in the
# sharded model. Requires lazy loader to be enabled to work properl
#==================================================================#
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers.modeling_utils
import torch
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
#==================================================================#