Fix error in aria2_hook when transformers version is at least 4.22.0
Some of the transformers.file_utils functions that were removed in transformers v4.22.0 have equivalents in the huggingface_hub module.
This commit is contained in:
parent
aac999c073
commit
551565c5ac
|
@ -1,4 +1,4 @@
|
||||||
transformers==4.21.3
|
transformers>=4.20.1
|
||||||
Flask
|
Flask
|
||||||
Flask-SocketIO
|
Flask-SocketIO
|
||||||
requests
|
requests
|
||||||
|
|
|
@ -6,7 +6,7 @@ optax >= 0.0.5, <= 0.0.9
|
||||||
dm-haiku == 0.0.5
|
dm-haiku == 0.0.5
|
||||||
jax == 0.2.21
|
jax == 0.2.21
|
||||||
jaxlib >= 0.1.69, <= 0.3.7
|
jaxlib >= 0.1.69, <= 0.3.7
|
||||||
transformers == 4.21.3
|
transformers >= 4.20.1
|
||||||
progressbar2
|
progressbar2
|
||||||
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
|
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
|
||||||
flask
|
flask
|
||||||
|
|
20
utils.py
20
utils.py
|
@ -10,6 +10,8 @@ import time
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
import os
|
import os
|
||||||
import itertools
|
import itertools
|
||||||
|
import hashlib
|
||||||
|
import huggingface_hub
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
vars = None
|
vars = None
|
||||||
|
@ -159,7 +161,7 @@ def num_layers(config):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Downloads huggingface checkpoints using aria2c if possible
|
# Downloads huggingface checkpoints using aria2c if possible
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
|
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
|
||||||
import transformers
|
import transformers
|
||||||
import transformers.modeling_utils
|
import transformers.modeling_utils
|
||||||
from huggingface_hub import HfFolder
|
from huggingface_hub import HfFolder
|
||||||
|
@ -186,8 +188,8 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
||||||
headers["authorization"] = f"Bearer {use_auth_token}"
|
headers["authorization"] = f"Bearer {use_auth_token}"
|
||||||
def is_cached(url):
|
def is_cached(url):
|
||||||
try:
|
try:
|
||||||
transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True)
|
huggingface_hub.cached_download(url, cache_dir=cache_dir, local_files_only=True)
|
||||||
except (FileNotFoundError, transformers.file_utils.EntryNotFoundError):
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
|
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
|
||||||
|
@ -195,7 +197,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
||||||
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
|
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return
|
return
|
||||||
url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror)
|
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
|
||||||
if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
|
if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
|
||||||
break
|
break
|
||||||
if sharded:
|
if sharded:
|
||||||
|
@ -205,18 +207,18 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
||||||
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
|
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
|
||||||
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
|
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
|
||||||
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
|
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
|
||||||
map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
|
map_filename = huggingface_hub.cached_download(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
|
||||||
with open(map_filename) as f:
|
with open(map_filename) as f:
|
||||||
map_data = json.load(f)
|
map_data = json.load(f)
|
||||||
filenames = set(map_data["weight_map"].values())
|
filenames = set(map_data["weight_map"].values())
|
||||||
urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
|
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
|
||||||
if not force_download:
|
if not force_download:
|
||||||
urls = [u for u in urls if not is_cached(u)]
|
urls = [u for u in urls if not is_cached(u)]
|
||||||
if not urls:
|
if not urls:
|
||||||
return
|
return
|
||||||
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
|
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
|
||||||
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
|
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
|
||||||
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
|
filenames = [hashlib.sha256(u.encode("utf-8")).hexdigest() + "." + hashlib.sha256(t.encode("utf-8")).hexdigest() for u, t in zip(urls, etags)]
|
||||||
for n in filenames:
|
for n in filenames:
|
||||||
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
|
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
|
@ -298,8 +300,8 @@ def get_num_shards(filename):
|
||||||
# pytorch_model.bin.index.json, returns a list of weight names in the
|
# pytorch_model.bin.index.json, returns a list of weight names in the
|
||||||
# sharded model. Requires lazy loader to be enabled to work properl
|
# sharded model. Requires lazy loader to be enabled to work properl
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
|
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
|
||||||
import transformers.modeling_utils
|
import transformers.modeling_utils
|
||||||
import torch
|
import torch
|
||||||
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
|
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
|
||||||
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
|
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
|
||||||
|
|
Loading…
Reference in New Issue