From 551565c5ac5d8572e9b8005a6c1798fcfa038b18 Mon Sep 17 00:00:00 2001
From: vfbd
Date: Thu, 15 Sep 2022 13:37:50 -0400
Subject: [PATCH 1/4] Fix error in aria2_hook when transformers version is
 at least 4.22.0

Some of the transformers.file_utils functions that were removed in
transformers v4.22.0 have equivalents in the huggingface_hub module.
---
 requirements.txt     |  4 ++--
 requirements_mtj.txt |  2 +-
 utils.py             | 20 +++++++++++---------
 3 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index b7872c86..efba7e4a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-transformers==4.21.3
+transformers>=4.20.1
 Flask
 Flask-SocketIO
 requests
@@ -11,4 +11,4 @@ markdown
 bleach==4.1.0
 sentencepiece
 protobuf
-accelerate
\ No newline at end of file
+accelerate
diff --git a/requirements_mtj.txt b/requirements_mtj.txt
index d8476727..90c68634 100644
--- a/requirements_mtj.txt
+++ b/requirements_mtj.txt
@@ -6,7 +6,7 @@ optax >= 0.0.5, <= 0.0.9
 dm-haiku == 0.0.5
 jax == 0.2.21
 jaxlib >= 0.1.69, <= 0.3.7
-transformers == 4.21.3
+transformers >= 4.20.1
 progressbar2
 git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
 flask
diff --git a/utils.py b/utils.py
index 710f61d2..4dd67068 100644
--- a/utils.py
+++ b/utils.py
@@ -10,6 +10,8 @@ import time
 from tqdm.auto import tqdm
 import os
 import itertools
+import hashlib
+import huggingface_hub
 from typing import Optional

 vars = None
@@ -159,7 +161,7 @@ def num_layers(config):
 #==================================================================#
 #  Downloads huggingface checkpoints using aria2c if possible
 #==================================================================#
-def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
+def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
     import transformers
     import transformers.modeling_utils
     from huggingface_hub import HfFolder
@@ -186,8 +188,8 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
         headers["authorization"] = f"Bearer {use_auth_token}"
     def is_cached(url):
         try:
-            transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True)
-        except (FileNotFoundError, transformers.file_utils.EntryNotFoundError):
+            huggingface_hub.cached_download(url, cache_dir=cache_dir, local_files_only=True)
+        except ValueError:
             return False
         return True
     while True:  # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
         try:
             filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
         except AttributeError:
             return
-        url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror)
+        url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
         if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
             break
         if sharded:
             return
@@ -205,18 +207,18 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
     if not sharded:  # If the model has a pytorch_model.bin file, that's the only file to download
         filenames = [transformers.modeling_utils.WEIGHTS_NAME]
     else:  # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
-        map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
+        map_filename = huggingface_hub.cached_download(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
         with open(map_filename) as f:
             map_data = json.load(f)
         filenames = set(map_data["weight_map"].values())
-    urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
+    urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
     if not force_download:
         urls = [u for u in urls if not is_cached(u)]
         if not urls:
             return
     etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
     headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
-    filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
+    filenames = [hashlib.sha256(u.encode("utf-8")).hexdigest() + "." + hashlib.sha256(t.encode("utf-8")).hexdigest() for u, t in zip(urls, etags)]
     for n in filenames:
         path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
         if os.path.exists(path):
             os.remove(path)
@@ -298,8 +300,8 @@ def get_num_shards(filename):
 #  pytorch_model.bin.index.json, returns a list of weight names in the
 #  sharded model. Requires lazy loader to be enabled to work properly
 #==================================================================#
-def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
+def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
     import transformers.modeling_utils
     import torch
-    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
+    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
     return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
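Note on PATCH 1/4: the substitutions above rest on rough equivalences between the removed transformers.file_utils helpers and huggingface_hub. A minimal sketch of those equivalences, not part of the patch itself (the repo id, file name and ETag below are placeholders, and it assumes a huggingface_hub version contemporary with transformers 4.22, where cached_download still exists):

    import hashlib
    import huggingface_hub

    # transformers.file_utils.hf_bucket_url(repo, name, revision=...) maps to:
    url = huggingface_hub.hf_hub_url("user/repo", "pytorch_model.bin", revision="main")

    # transformers.file_utils.get_from_cache / cached_path map to cached_download,
    # which raises ValueError when local_files_only=True and the file is not
    # cached -- hence the patch's switch to "except ValueError".
    path = huggingface_hub.cached_download(url, cache_dir="cache", local_files_only=True)

    # transformers.file_utils.url_to_filename(url, etag) named legacy cache
    # entries "sha256(url).sha256(etag)"; the patch inlines the same scheme:
    etag = "dummy-etag"  # placeholder
    name = hashlib.sha256(url.encode("utf-8")).hexdigest() + "." + hashlib.sha256(etag.encode("utf-8")).hexdigest()
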
From 7bf6c9a23f451b1d8fc61b9cdf916ca4c864a5b4 Mon Sep 17 00:00:00 2001
From: vfbd
Date: Thu, 15 Sep 2022 13:47:48 -0400
Subject: [PATCH 2/4] Remove TPU Colab's dependency on optax and chex

---
 requirements_mtj.txt |  2 --
 tpu_mtj_backend.py   | 13 ++++++++++---
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/requirements_mtj.txt b/requirements_mtj.txt
index 90c68634..613e9203 100644
--- a/requirements_mtj.txt
+++ b/requirements_mtj.txt
@@ -2,7 +2,6 @@ torch >= 1.9, <= 1.11
 numpy
 tqdm
 requests
-optax >= 0.0.5, <= 0.0.9
 dm-haiku == 0.0.5
 jax == 0.2.21
 jaxlib >= 0.1.69, <= 0.3.7
@@ -17,4 +16,3 @@ eventlet
 lupa==1.10
 markdown
 bleach==4.1.0
-chex==0.1.4
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index da0511df..0c6667a2 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -30,7 +30,7 @@ SOFTWARE.

 import utils
 import multiprocessing
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
 import progressbar
 import time
 import os
@@ -45,7 +45,6 @@ from jax.config import config
 from jax.experimental import maps
 import jax.numpy as jnp
 import numpy as np
-import optax
 import haiku as hk
 from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
 from tokenizers import Tokenizer
@@ -120,6 +119,14 @@ def __batch_xmap(shard_dim=1):
     return inner


+class _EmptyState(NamedTuple):
+    pass
+
+class _DummyOptimizer:
+    def init(*args, **kwargs):
+        return _EmptyState()
+
+
 def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
@@ -1148,7 +1155,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     cores_per_replica = params["cores_per_replica"]
     seq = params["seq"]

-    params["optimizer"] = optax.scale(0)
+    params["optimizer"] = _DummyOptimizer()
     mesh_shape = (1, cores_per_replica)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
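Note on PATCH 2/4: optax was only being used to build a throwaway optimizer (optax.scale(0)) for inference, so the patch swaps it for a stub. The idea as a standalone sketch, assuming, as the TPU backend's usage implies, that only the optimizer's init method is ever called when no training step is taken:

    from typing import NamedTuple

    class _EmptyState(NamedTuple):
        pass

    class _DummyOptimizer:
        # Stands in for optax.scale(0): init just has to return some (empty,
        # hashable, picklable) state object; update is never called.
        def init(*args, **kwargs):
            return _EmptyState()

    state = _DummyOptimizer().init(None)  # -> _EmptyState()

The bare *args signature means the method works whether it is called on the instance or on the class, mirroring the patch's code exactly.
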
From 463bf86bcc2c3b0a4a8b21c1e8c628c61c451f76 Mon Sep 17 00:00:00 2001
From: vfbd
Date: Thu, 15 Sep 2022 16:50:43 -0400
Subject: [PATCH 3/4] aria2_hook now uses new cache format if you have
 transformers 4.22

---
 aiserver.py |   8 +-
 utils.py    | 308 ++++++++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 267 insertions(+), 49 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index aedde57e..5860931b 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1713,11 +1713,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             import transformers.configuration_utils
             import transformers.modeling_utils
             import transformers.file_utils
+            import huggingface_hub
+            legacy = packaging.version.parse(transformers_version) < packaging.version.parse("4.22.0.dev0")
             # Save the config.json
-            shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
+            shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
             if(utils.num_shards is None):
                 # Save the pytorch_model.bin of an unsharded model
-                shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
+                shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
             else:
                 with open(utils.from_pretrained_index_filename) as f:
                     map_data = json.load(f)
@@ -1726,7 +1728,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 shutil.move(utils.from_pretrained_index_filename, os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_INDEX_NAME))
                 # Then save the pytorch_model-#####-of-#####.bin files
                 for filename in filenames:
-                    shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
+                    shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, filename, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
             shutil.rmtree("cache/")

 if(vars.hascuda):
diff --git a/utils.py b/utils.py
index 4dd67068..d456ae32 100644
--- a/utils.py
+++ b/utils.py
@@ -4,6 +4,7 @@ import shutil
 import json
 import subprocess
 import tempfile
+from urllib.error import HTTPError
 import requests
 import requests.adapters
 import time
@@ -12,6 +13,8 @@ import os
 import itertools
 import hashlib
 import huggingface_hub
+import packaging.version
+from pathlib import Path
 from typing import Optional

 vars = None
@@ -161,6 +164,262 @@ def num_layers(config):
 #==================================================================#
 #  Downloads huggingface checkpoints using aria2c if possible
 #==================================================================#
+def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None):
+    import transformers
+    lengths = {}
+    s = requests.Session()
+    s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
+    bar = None
+    done = False
+    secret = os.urandom(17).hex()
+    try:
+        with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
+            f.write(aria2_config)
+            f.flush()
+            p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", directory, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {use_auth_token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+            while p.poll() is None:
+                r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
+                if not r:
+                    s.close()
+                    if bar is not None:
+                        bar.n = bar.total
+                        bar.close()
+                    p.terminate()
+                    done = True
+                    break
+                if bar is None:
+                    bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
+                visited = set()
+                for x in r:
+                    filename = x["files"][0]["path"]
+                    lengths[filename] = (int(x["completedLength"]), int(x["totalLength"]))
+                    visited.add(filename)
+                for k, v in lengths.items():
+                    if k not in visited:
+                        lengths[k] = (v[1], v[1])
+                bar.n = sum(v[0] for v in lengths.values())
+                bar.update()
+                time.sleep(0.1)
+        path = f.name
+    except Exception as e:
+        p.terminate()
+        raise e
+    finally:
+        try:
+            os.remove(path)
+        except OSError:
+            pass
+    code = p.wait()
+    if not done and code:
+        raise OSError(f"aria2 exited with exit code {code}")
+
+def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
+    import transformers
+    import transformers.modeling_utils
+    from huggingface_hub import HfFolder
+    if use_auth_token:
+        if isinstance(use_auth_token, str):
+            token = use_auth_token
+        else:
+            token = HfFolder.get_token()
+            if token is None:
+                raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
+    _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
+    _revision = revision if revision is not None else huggingface_hub.constants.DEFAULT_REVISION
+    sharded = False
+    headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
+    if use_auth_token:
+        headers["authorization"] = f"Bearer {use_auth_token}"
+
+    storage_folder = os.path.join(_cache_dir, huggingface_hub.file_download.repo_folder_name(repo_id=pretrained_model_name_or_path, repo_type="model"))
+    os.makedirs(storage_folder, exist_ok=True)
+
+    def is_cached(filename):
+        try:
+            huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, local_files_only=True)
+        except ValueError:
+            return False
+        return True
+    while True:  # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
+        try:
+            filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
+        except AttributeError:
+            return
+        url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
+        if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
+            break
+        if sharded:
+            return
+        else:
+            sharded = True
+    if not sharded:  # If the model has a pytorch_model.bin file, that's the only file to download
+        filenames = [transformers.modeling_utils.WEIGHTS_NAME]
+    else:  # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
+        map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
+        with open(map_filename) as f:
+            map_data = json.load(f)
+        filenames = set(map_data["weight_map"].values())
+    urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
+    if not force_download:
+        urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]
+        if not urls:
+            return
+
+    blob_paths = []
+
+    # This section is a modified version of hf_hub_download from huggingface_hub
+    # See https://github.com/huggingface/huggingface_hub/blob/main/LICENSE for license
+    for u, n in zip(urls, filenames):
+        relative_filename = os.path.join(*n.split("/"))
+        if not local_files_only:
+            try:
+                r = huggingface_hub.file_download._request_wrapper(
+                    method="HEAD",
+                    url=u,
+                    headers=headers,
+                    allow_redirects=False,
+                    follow_relative_redirects=True,
+                    proxies=proxies,
+                    timeout=10,
+                )
+                try:
+                    r.raise_for_status()
+                except HTTPError as e:
+                    error_code = r.headers.get("X-Error-Code")
+                    if error_code != "EntryNotFound":
+                        raise RuntimeError(f"HEAD {u} failed with error code {r.status_code}")
+                    commit_hash = r.headers.get(huggingface_hub.file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT)
+                    if commit_hash is not None:
+                        no_exist_file_path = (
+                            Path(storage_folder)
+                            / ".no_exist"
+                            / commit_hash
+                            / relative_filename
+                        )
+                        no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
+                        no_exist_file_path.touch()
+                        huggingface_hub.file_download._cache_commit_hash_for_specific_revision(
+                            storage_folder, _revision, commit_hash
+                        )
+                    raise
+                commit_hash = r.headers[huggingface_hub.file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT]
+                if commit_hash is None:
+                    raise OSError(
+                        "Distant resource does not seem to be on huggingface.co (missing"
+                        " commit header)."
+                    )
+                etag = r.headers.get(huggingface_hub.file_download.HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get(
+                    "ETag"
+                )
+                # We favor a custom header indicating the etag of the linked resource, and
+                # we fallback to the regular etag header.
+                # If we don't have any of those, raise an error.
+                if etag is None:
+                    raise OSError(
+                        "Distant resource does not have an ETag, we won't be able to"
+                        " reliably ensure reproducibility."
+                    )
+                etag = huggingface_hub.file_download._normalize_etag(etag)
+                # In case of a redirect, save an extra redirect on the request.get call,
+                # and ensure we download the exact atomic version even if it changed
+                # between the HEAD and the GET (unlikely, but hey).
+                # Useful for lfs blobs that are stored on a CDN.
+                if 300 <= r.status_code <= 399:
+                    url_to_download = r.headers["Location"]
+                    if (
+                        "lfs.huggingface.co" in url_to_download
+                        or "lfs-staging.huggingface.co" in url_to_download
+                    ):
+                        # Remove authorization header when downloading a LFS blob
+                        headers.pop("authorization", None)
+            except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
+                # Actually raise for those subclasses of ConnectionError
+                raise
+            except (
+                requests.exceptions.ConnectionError,
+                requests.exceptions.Timeout,
+                huggingface_hub.file_download.OfflineModeIsEnabled,
+            ):
+                # Otherwise, our Internet connection is down.
+                # etag is None
+                pass
+        if etag is None:
+            # In those cases, we cannot force download.
+            if force_download:
+                raise ValueError(
+                    "We have no connection or you passed local_files_only, so"
+                    " force_download is not an accepted option."
+                )
+            if huggingface_hub.file_download.REGEX_COMMIT_HASH.match(_revision):
+                commit_hash = _revision
+            else:
+                ref_path = os.path.join(storage_folder, "refs", _revision)
+                with open(ref_path) as f:
+                    commit_hash = f.read()
+            pointer_path = os.path.join(
+                storage_folder, "snapshots", commit_hash, relative_filename
+            )
+            if os.path.exists(pointer_path):
+                return pointer_path
+            # If we couldn't find an appropriate file on disk,
+            # raise an error.
+            # If files cannot be found and local_files_only=True,
+            # the models might've been found if local_files_only=False
+            # Notify the user about that
+            if local_files_only:
+                raise huggingface_hub.file_download.LocalEntryNotFoundError(
+                    "Cannot find the requested files in the disk cache and"
+                    " outgoing traffic has been disabled. To enable hf.co look-ups"
+                    " and downloads online, set 'local_files_only' to False."
+                )
+            else:
+                raise huggingface_hub.file_download.LocalEntryNotFoundError(
+                    "Connection error, and we cannot find the requested files in"
+                    " the disk cache. Please try again or make sure your Internet"
+                    " connection is on."
+                )
+        # From now on, etag and commit_hash are not None.
+        blob_path = os.path.join(storage_folder, "blobs", etag)
+        pointer_path = os.path.join(
+            storage_folder, "snapshots", commit_hash, relative_filename
+        )
+        os.makedirs(os.path.dirname(blob_path), exist_ok=True)
+        os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
+        # if passed revision is not identical to commit_hash
+        # then revision has to be a branch name or tag name.
+        # In that case store a ref.
+        huggingface_hub.file_download._cache_commit_hash_for_specific_revision(storage_folder, _revision, commit_hash)
+        if os.path.exists(pointer_path) and not force_download:
+            return pointer_path
+        if os.path.exists(blob_path) and not force_download:
+            # we have the blob already, but not the pointer
+            huggingface_hub.file_download.logger.info("creating pointer to %s from %s", blob_path, pointer_path)
+            huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
+            return pointer_path
+        # Some Windows versions do not allow for paths longer than 255 characters.
+        # In this case, we must specify it is an extended path by using the "\\?\" prefix.
+        if os.name == "nt" and len(os.path.abspath(blob_path)) > 255:
+            blob_path = "\\\\?\\" + os.path.abspath(blob_path)
+        blob_paths.append(blob_path)
+
+    filenames = blob_paths
+    headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
+
+    for n in filenames:
+        prefix, suffix = n.rsplit("/", 1)
+        path = os.path.join(prefix, "kai-tempfile." + suffix + ".aria2")
+        if os.path.exists(path):
+            os.remove(path)
+        path = os.path.join(prefix, "kai-tempfile." + suffix)
+        if os.path.exists(path):
+            os.remove(path)
+    total_length = sum(int(h["Content-Length"]) for h in headers)
+    aria2_config = "\n".join(f"{u}\n  out={os.path.join(prefix, 'kai-tempfile.' + suffix)}" for u, n in zip(urls, filenames) for prefix, suffix in [n.rsplit("/", 1)]).encode()
+    _download_with_aria2(aria2_config, total_length, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
+    for u, n in zip(urls, filenames):
+        prefix, suffix = n.rsplit("/", 1)
+        os.rename(os.path.join(prefix, "kai-tempfile." + suffix), os.path.join(prefix, suffix))
+
 def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
     import transformers
     import transformers.modeling_utils
@@ -174,6 +433,8 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
     if proxies:
         print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy.  Disabling aria2 download mode.")
         return
+    if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.22.0.dev0"):
+        return _transformers22_aria2_hook(pretrained_model_name_or_path, force_download=force_download, cache_dir=cache_dir, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, **kwargs)
     if use_auth_token:
         if isinstance(use_auth_token, str):
             token = use_auth_token
@@ -234,53 +495,8 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
         if os.path.exists(path):
             os.remove(path)
     total_length = sum(int(h["Content-Length"]) for h in headers)
-    lengths = {}
     aria2_config = "\n".join(f"{u}\n  out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
-    s = requests.Session()
-    s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
-    bar = None
-    done = False
-    secret = os.urandom(17).hex()
-    try:
-        with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
-            f.write(aria2_config)
-            f.flush()
-            p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-            while p.poll() is None:
-                r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
-                if not r:
-                    s.close()
-                    if bar is not None:
-                        bar.n = bar.total
-                        bar.close()
-                    p.terminate()
-                    done = True
-                    break
-                if bar is None:
-                    bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
-                visited = set()
-                for x in r:
-                    filename = x["files"][0]["path"]
-                    lengths[filename] = (int(x["completedLength"]), int(x["totalLength"]))
-                    visited.add(filename)
-                for k, v in lengths.items():
-                    if k not in visited:
-                        lengths[k] = (v[1], v[1])
-                bar.n = sum(v[0] for v in lengths.values())
-                bar.update()
-                time.sleep(0.1)
-        path = f.name
-    except Exception as e:
-        p.terminate()
-        raise e
-    finally:
-        try:
-            os.remove(path)
-        except OSError:
-            pass
-    code = p.wait()
-    if not done and code:
-        raise OSError(f"aria2 exited with exit code {code}")
+    _download_with_aria2(aria2_config, total_length, directory=_cache_dir, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
     for u, t, n in zip(urls, etags, filenames):
         os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
         with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
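Note on PATCH 3/4: the "new cache format" is huggingface_hub's revision-aware layout, where file contents live under blobs/ keyed by (normalized) ETag and per-commit snapshots/ directories hold symlinks to them. A rough sketch of the paths the hook now has to compute, using only helpers the patch itself calls (the repo id, ETag and commit hash are placeholders):

    import os
    import huggingface_hub

    storage_folder = os.path.join("cache", huggingface_hub.file_download.repo_folder_name(repo_id="user/repo", repo_type="model"))
    # e.g. cache/models--user--repo
    blob_path = os.path.join(storage_folder, "blobs", "<etag>")                                  # actual bytes
    pointer_path = os.path.join(storage_folder, "snapshots", "<commit>", "pytorch_model.bin")    # symlink into blobs/
    ref_path = os.path.join(storage_folder, "refs", "main")                                      # stores the commit "main" points at

aria2 therefore downloads to a "kai-tempfile.<etag>" sibling of blob_path, the file is renamed into place, and huggingface_hub.file_download._create_relative_symlink provides the snapshot pointer.
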
From d55d8232d0ec659e7a7cb90e58cfe419d675ae7b Mon Sep 17 00:00:00 2001
From: vfbd
Date: Thu, 15 Sep 2022 17:07:53 -0400
Subject: [PATCH 4/4] Unpin transformers version in Conda environments

---
 environments/huggingface.yml | 2 +-
 environments/rocm.yml        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/environments/huggingface.yml b/environments/huggingface.yml
index 62538529..205d5e31 100644
--- a/environments/huggingface.yml
+++ b/environments/huggingface.yml
@@ -20,5 +20,5 @@ dependencies:
     - flask-cloudflared
     - flask-ngrok
     - lupa==1.10
-    - transformers==4.21.3
+    - transformers>=4.20.1
     - accelerate
\ No newline at end of file
diff --git a/environments/rocm.yml b/environments/rocm.yml
index fb3336e9..8ade341f 100644
--- a/environments/rocm.yml
+++ b/environments/rocm.yml
@@ -20,5 +20,5 @@ dependencies:
     - flask-cloudflared
     - flask-ngrok
     - lupa==1.10
-    - transformers==4.21.3
+    - transformers>=4.20.1
     - accelerate
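
A final detail worth noting across the series: the version gate in patch 3 compares against "4.22.0.dev0" rather than "4.22.0", so source installs of transformers 4.22 pre-releases also take the new code path. This works because packaging orders dev releases before the final release:

    import packaging.version

    v = packaging.version.parse
    assert v("4.21.3") < v("4.22.0.dev0") <= v("4.22.0.dev20220915") < v("4.22.0")
    # ">= 4.22.0.dev0" also matches 4.22 pre-release builds,
    # which a ">= 4.22.0" check would miss.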