From b1d8797a54d4e22b6d43062930ed63765586bdc5 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 12 May 2022 23:51:40 -0400
Subject: [PATCH] Allow TPU Colab to load sharded HF models

---
 aiserver.py          | 20 +++++++++++++++++++-
 requirements_mtj.txt |  2 +-
 tpu_mtj_backend.py   | 12 ++++++++++--
 utils.py             | 14 +++++++++++++-
 4 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 1f105701..6c9401e2 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1127,13 +1127,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     from transformers import __version__ as transformers_version
     from transformers import PreTrainedModel
+    from transformers import modeling_utils
     old_from_pretrained = PreTrainedModel.from_pretrained.__func__
     @classmethod
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        utils.num_shards = None
+        utils.current_shard = 0
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     PreTrainedModel.from_pretrained = new_from_pretrained
+    old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+    def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+        utils.num_shards = utils.get_num_shards(index_filename)
+        return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+    modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
     # Lazy loader
     import torch_lazy_loader
@@ -1170,7 +1178,9 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 last_storage_key = None
                 f = None
                 current_offset = 0
-                for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+                if utils.num_shards is not None:
+                    utils.current_shard += 1
+                for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
                     storage_key = model_dict[key].key
                     if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
                         last_storage_key = storage_key
@@ -1560,13 +1570,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
     else:
         from transformers import PreTrainedModel
+        from transformers import modeling_utils
         old_from_pretrained = PreTrainedModel.from_pretrained.__func__
         @classmethod
         def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+            utils.num_shards = None
+            utils.current_shard = 0
             if not args.no_aria2:
                 utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
             return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
         PreTrainedModel.from_pretrained = new_from_pretrained
+        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+            utils.num_shards = utils.get_num_shards(index_filename)
+            return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
 def tpumtjgetsofttokens():
     soft_tokens = None
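
Note: the aiserver.py hunks above monkey-patch two transformers entry points
rather than forking the library. new_from_pretrained resets the shard counters
before every load (and keeps the existing aria2 download hook), while
new_get_checkpoint_shard_files reads the shard count out of the index file
before delegating to the stock implementation. A minimal, self-contained
sketch of the same wrap-and-delegate pattern (the underscore-prefixed names
and the print are illustrative, not part of the patch):

    import json
    from transformers import modeling_utils

    # Keep a reference to the original so the wrapper can delegate to it.
    _old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files

    def _wrapped_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
        # Peek at the index to learn the shard count, then hand off unchanged.
        with open(index_filename) as f:
            num_shards = len(set(json.load(f)["weight_map"].values()))
        print(f"about to fetch {num_shards} shards")
        return _old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)

    modeling_utils.get_checkpoint_shard_files = _wrapped_get_checkpoint_shard_files

Because the wrapper forwards *args and **kwargs untouched, it keeps working
even if transformers later grows new keyword arguments on this function.
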
diff --git a/requirements_mtj.txt b/requirements_mtj.txt
index e2a6c4e1..0f723a49 100644
--- a/requirements_mtj.txt
+++ b/requirements_mtj.txt
@@ -7,7 +7,7 @@ dm-haiku == 0.0.5
 jax == 0.2.21
 transformers >= 4.19
 progressbar2
-git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck-staging
+git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
 flask
 Flask-SocketIO
 flask-cloudflared >= 0.0.5
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index b956648b..2fa149d7 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -27,6 +27,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 '''
 
+import utils
+
 import multiprocessing
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 import progressbar
@@ -1163,8 +1165,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         last_storage_key = None
         f = None
         current_offset = 0
-        print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
-        for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+        if utils.current_shard == 0:
+            print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
+        if utils.num_shards is not None:
+            utils.current_shard += 1
+        for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
 
             # Some model weights are used by transformers but not by MTJ.
             # We have to materialize these weights anyways because
@@ -1225,6 +1230,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 np.empty(params["cores_per_replica"]),
             )
 
+        if utils.num_shards is not None and utils.current_shard < utils.num_shards:
+            return
+
         # Check for tensors that MTJ needs that were not provided in the
         # HF model
         for mk, mv in network.state["params"].items():
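
Note: with a sharded checkpoint, transformers materializes one shard file at a
time, so the load callback in tpu_mtj_backend.py now fires once per shard.
That is why the parameter-count banner only prints while utils.current_shard
is still 0, the tqdm description gains a "(shard X/Y)" suffix, and the
post-load consistency checks are deferred with an early return until the last
shard. A rough sketch of that gating logic, assuming one callback call per
shard and that the patch's utils module is importable (the function name and
prints are hypothetical):

    import utils

    def on_shard_callback(shard_tensors):
        if utils.current_shard == 0:
            print("banner: printed once, before the first shard")
        if utils.num_shards is not None:
            utils.current_shard += 1
        # ... load this shard's tensors into the TPU here ...
        if utils.num_shards is not None and utils.current_shard < utils.num_shards:
            return  # more shards follow; defer finalization
        print("final shard done: run the MTJ consistency checks")

For an unsharded checkpoint utils.num_shards stays None, both guards are
skipped, and the callback behaves exactly as before the patch.
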
diff --git a/utils.py b/utils.py
index c6eb85ec..0fdfa125 100644
--- a/utils.py
+++ b/utils.py
@@ -6,8 +6,11 @@ import subprocess
 import tempfile
 import requests
 import os
+from typing import Optional
 
 vars = None
+num_shards: Optional[int] = None
+current_shard = 0
 
 #==================================================================#
 # Decorator to prevent a function's actions from being run until
@@ -133,7 +136,7 @@ def decodenewlines(txt):
     return txt
 
 #==================================================================#
-# Downloads huggingface checkpoints using aria2c if possible
+# Downloads sharded huggingface checkpoints using aria2c if possible
 #==================================================================#
 def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
     import transformers
@@ -225,3 +228,12 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
         os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
         with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
             json.dump({"url": u, "etag": t}, f)
+
+#==================================================================#
+# Given the path to a pytorch_model.bin.index.json, returns how many
+# shards there are in the model
+#==================================================================#
+def get_num_shards(filename):
+    with open(filename) as f:
+        map_data = json.load(f)
+    return len(set(map_data["weight_map"].values()))
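
Note: pytorch_model.bin.index.json is the index transformers writes next to a
sharded checkpoint; its "weight_map" object maps every tensor name to the
shard file that contains it. Since many tensors live in each shard file,
get_num_shards() deduplicates the file names with set() before counting. A
toy example of that counting step (the tensor names, file names, and size are
illustrative):

    import json

    index = json.loads("""
    {
        "metadata": {"total_size": 13460684800},
        "weight_map": {
            "transformer.wte.weight": "pytorch_model-00001-of-00002.bin",
            "transformer.h.0.attn.attention.q_proj.weight": "pytorch_model-00001-of-00002.bin",
            "transformer.h.27.mlp.c_proj.bias": "pytorch_model-00002-of-00002.bin"
        }
    }
    """)
    print(len(set(index["weight_map"].values())))  # -> 2 shards
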