diff --git a/requirements.txt b/requirements.txt index b7872c86..efba7e4a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -transformers==4.21.3 +transformers>=4.20.1 Flask Flask-SocketIO requests @@ -11,4 +11,4 @@ markdown bleach==4.1.0 sentencepiece protobuf -accelerate \ No newline at end of file +accelerate diff --git a/requirements_mtj.txt b/requirements_mtj.txt index d8476727..90c68634 100644 --- a/requirements_mtj.txt +++ b/requirements_mtj.txt @@ -6,7 +6,7 @@ optax >= 0.0.5, <= 0.0.9 dm-haiku == 0.0.5 jax == 0.2.21 jaxlib >= 0.1.69, <= 0.3.7 -transformers == 4.21.3 +transformers >= 4.20.1 progressbar2 git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck flask diff --git a/utils.py b/utils.py index 710f61d2..4dd67068 100644 --- a/utils.py +++ b/utils.py @@ -10,6 +10,8 @@ import time from tqdm.auto import tqdm import os import itertools +import hashlib +import huggingface_hub from typing import Optional vars = None @@ -159,7 +161,7 @@ def num_layers(config): #==================================================================# # Downloads huggingface checkpoints using aria2c if possible #==================================================================# -def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs): +def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs): import transformers import transformers.modeling_utils from huggingface_hub import HfFolder @@ -186,8 +188,8 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d headers["authorization"] = f"Bearer {use_auth_token}" def is_cached(url): try: - transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True) - except (FileNotFoundError, transformers.file_utils.EntryNotFoundError): + huggingface_hub.cached_download(url, cache_dir=cache_dir, local_files_only=True) + except ValueError: return False return True while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file @@ -195,7 +197,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME except AttributeError: return - url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror) + url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision) if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers): break if sharded: @@ -205,18 +207,18 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download filenames = [transformers.modeling_utils.WEIGHTS_NAME] else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it - map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent) + map_filename = huggingface_hub.cached_download(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent) with open(map_filename) as f: map_data = json.load(f) filenames = set(map_data["weight_map"].values()) - urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames] + urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames] if not force_download: urls = [u for u in urls if not is_cached(u)] if not urls: return etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]] headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls] - filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)] + filenames = [hashlib.sha256(u.encode("utf-8")).hexdigest() + "." + hashlib.sha256(t.encode("utf-8")).hexdigest() for u, t in zip(urls, etags)] for n in filenames: path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2") if os.path.exists(path): @@ -298,8 +300,8 @@ def get_num_shards(filename): # pytorch_model.bin.index.json, returns a list of weight names in the # sharded model. Requires lazy loader to be enabled to work properl #==================================================================# -def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs): +def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs): import transformers.modeling_utils import torch - shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror) + shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision) return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))