Allow TPU Colab to load sharded HF models

parent 4fa5f1cd6a
commit b1d8797a54

aiserver.py (20 lines changed)
@@ -1127,13 +1127,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         from transformers import __version__ as transformers_version
         from transformers import PreTrainedModel
+        from transformers import modeling_utils
         old_from_pretrained = PreTrainedModel.from_pretrained.__func__
         @classmethod
         def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+            utils.num_shards = None
+            utils.current_shard = 0
             if not args.no_aria2:
                 utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
             return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
         PreTrainedModel.from_pretrained = new_from_pretrained
+        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+            utils.num_shards = utils.get_num_shards(index_filename)
+            return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
         # Lazy loader
         import torch_lazy_loader
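
Note: from_pretrained is a classmethod, so the hunk above has to capture the
underlying function via .__func__ and re-apply @classmethod to the wrapper
before assigning it back. A minimal toy sketch of that wrapping pattern,
assuming nothing beyond the Python standard library (the Loader class and the
hook print are invented for illustration, not part of this commit):

    # Toy demonstration of the classmethod-wrapping pattern used above.
    class Loader:
        @classmethod
        def load(cls, name):
            return f"{cls.__name__} loading {name}"

    # .__func__ unwraps the bound classmethod into a plain function
    # that takes cls explicitly.
    old_load = Loader.load.__func__

    @classmethod
    def new_load(cls, name):
        print("hook runs before the original")  # side effect, e.g. a download hook
        return old_load(cls, name)              # delegate to the original

    Loader.load = new_load
    assert Loader.load("model") == "Loader loading model"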
@@ -1170,7 +1178,9 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 last_storage_key = None
                 f = None
                 current_offset = 0
-                for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+                if utils.num_shards is not None:
+                    utils.current_shard += 1
+                for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
                     storage_key = model_dict[key].key
                     if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
                         last_storage_key = storage_key
@@ -1560,13 +1570,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
     else:
         from transformers import PreTrainedModel
+        from transformers import modeling_utils
         old_from_pretrained = PreTrainedModel.from_pretrained.__func__
         @classmethod
         def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+            utils.num_shards = None
+            utils.current_shard = 0
             if not args.no_aria2:
                 utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
             return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
         PreTrainedModel.from_pretrained = new_from_pretrained
+        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+            utils.num_shards = utils.get_num_shards(index_filename)
+            return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
     def tpumtjgetsofttokens():
         soft_tokens = None
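
Note: in both patched tensor-loading loops above, the tqdm description only
gains a "(shard k/n)" suffix when utils.num_shards was set by the
get_checkpoint_shard_files wrapper; unsharded models keep the old text. A
small self-contained sketch of that description logic (the shard count and
tensor names are made up for illustration):

    from tqdm import tqdm

    num_shards = 3       # stays None for an unsharded checkpoint
    current_shard = 0

    for _ in range(num_shards or 1):
        if num_shards is not None:
            current_shard += 1
        desc = "Loading model tensors" + (
            f" (shard {current_shard}/{num_shards})" if num_shards is not None else ""
        )
        for key in tqdm(["wte.weight", "ln_f.bias"], desc=desc):
            pass  # load one tensor belonging to the current shard file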
requirements_mtj.txt

@@ -7,7 +7,7 @@ dm-haiku == 0.0.5
 jax == 0.2.21
 transformers >= 4.19
 progressbar2
-git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck-staging
+git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
 flask
 Flask-SocketIO
 flask-cloudflared >= 0.0.5
tpu_mtj_backend.py

@@ -27,6 +27,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 '''
 
+import utils
+
 import multiprocessing
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 import progressbar
@@ -1163,8 +1165,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         last_storage_key = None
         f = None
         current_offset = 0
-        print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
-        for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+        if utils.current_shard == 0:
+            print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
+        if utils.num_shards is not None:
+            utils.current_shard += 1
+        for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
 
             # Some model weights are used by transformers but not by MTJ.
             # We have to materialize these weights anyways because
@@ -1225,6 +1230,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 np.empty(params["cores_per_replica"]),
             )
 
+        if utils.num_shards is not None and utils.current_shard < utils.num_shards:
+            return
+
         # Check for tensors that MTJ needs that were not provided in the
         # HF model
         for mk, mv in network.state["params"].items():
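
Note: for a sharded checkpoint, transformers materializes one shard file at a
time, so the loading code inside load_model runs once per shard. The hunks
above therefore print the parameter count only on the first pass and defer the
final "tensors MTJ needs" check until the last pass. A rough sketch of that
per-shard control flow, with hypothetical names (the real logic lives inline
in tpu_mtj_backend.load_model):

    def on_shard_loaded(num_shards, state):
        if state["current_shard"] == 0:
            print("first pass: report parameter count once")
        if num_shards is not None:
            state["current_shard"] += 1
        # ... copy this shard's tensors into the network state ...
        if num_shards is not None and state["current_shard"] < num_shards:
            return  # more shards coming; skip the final checks for now
        print("last pass: verify every tensor MTJ needs was provided")

    state = {"current_shard": 0}
    for _ in range(3):          # transformers triggers this once per shard
        on_shard_loaded(3, state)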
utils.py (14 lines changed)
@@ -6,8 +6,11 @@ import subprocess
 import tempfile
 import requests
 import os
+from typing import Optional
 
 vars = None
+num_shards: Optional[int] = None
+current_shard = 0
 
 #==================================================================#
 # Decorator to prevent a function's actions from being run until
@@ -133,7 +136,7 @@ def decodenewlines(txt):
     return txt
 
 #==================================================================#
-# Downloads sharded huggingface checkpoints using aria2c if possible
+# Downloads huggingface checkpoints using aria2c if possible
 #==================================================================#
 def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
     import transformers
@@ -225,3 +228,12 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
         os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
         with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
             json.dump({"url": u, "etag": t}, f)
+
+#==================================================================#
+# Given the path to a pytorch_model.bin.index.json, returns how many
+# shards there are in the model
+#==================================================================#
+def get_num_shards(filename):
+    with open(filename) as f:
+        map_data = json.load(f)
+    return len(set(map_data["weight_map"].values()))
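
Note: get_num_shards counts distinct shard filenames rather than weight
entries, because many weights map into the same shard file. A sketch of the
pytorch_model.bin.index.json layout it reads (weight names and sizes below are
invented for illustration):

    # Parsed contents of a hypothetical pytorch_model.bin.index.json:
    index = {
        "metadata": {"total_size": 13884749824},
        "weight_map": {
            "transformer.wte.weight": "pytorch_model-00001-of-00003.bin",
            "transformer.h.0.attn.w": "pytorch_model-00001-of-00003.bin",
            "transformer.h.12.attn.w": "pytorch_model-00002-of-00003.bin",
            "lm_head.weight": "pytorch_model-00003-of-00003.bin",
        },
    }
    # Same counting logic as utils.get_num_shards:
    print(len(set(index["weight_map"].values())))  # -> 3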