Allow TPU Colab to load sharded HF models

Gnome Ann 2022-05-12 23:51:40 -04:00
parent 4fa5f1cd6a
commit b1d8797a54
4 changed files with 43 additions and 5 deletions

@@ -1127,13 +1127,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     from transformers import __version__ as transformers_version
     from transformers import PreTrainedModel
+    from transformers import modeling_utils
     old_from_pretrained = PreTrainedModel.from_pretrained.__func__
     @classmethod
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        utils.num_shards = None
+        utils.current_shard = 0
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     PreTrainedModel.from_pretrained = new_from_pretrained
+    old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+    def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+        utils.num_shards = utils.get_num_shards(index_filename)
+        return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+    modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
     # Lazy loader
     import torch_lazy_loader
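
For context: this hunk monkey-patches two transformers entry points. `from_pretrained` is wrapped so the shard counters in `utils` are reset before every load, and `modeling_utils.get_checkpoint_shard_files` is wrapped so the shard count is recorded as soon as a sharded checkpoint's index file is resolved. A minimal standalone sketch of the same wrap-and-delegate pattern (the `shard_state` holder is a hypothetical stand-in for KoboldAI's `utils` module; it assumes a transformers version that exposes `modeling_utils.get_checkpoint_shard_files`, e.g. the >= 4.19 pinned below):

# Sketch only: the wrap-and-delegate pattern in isolation.
# `shard_state` is a hypothetical stand-in for KoboldAI's `utils` module.
import json
from transformers import modeling_utils

class shard_state:
    num_shards = None    # None means "checkpoint is not sharded"
    current_shard = 0    # bumped once per shard while tensors are copied

_orig_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files

def _wrapped_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
    # The index JSON maps each weight name to the shard file holding it,
    # so counting distinct file names gives the number of shards.
    with open(index_filename) as f:
        shard_state.num_shards = len(set(json.load(f)["weight_map"].values()))
    return _orig_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)

modeling_utils.get_checkpoint_shard_files = _wrapped_get_checkpoint_shard_files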
@@ -1170,7 +1178,9 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     last_storage_key = None
     f = None
     current_offset = 0
-    for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+    if utils.num_shards is not None:
+        utils.current_shard += 1
+    for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
         storage_key = model_dict[key].key
         if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
             last_storage_key = storage_key
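
The lazy-loader callback now runs once per shard rather than once per model, so the tqdm description gains a shard counter whenever one is known. A tiny standalone sketch of just that label logic (the helper function name is made up; the real code builds the string inline):

# Sketch of the progress-bar label built above: only append a shard
# suffix when the checkpoint is actually sharded (num_shards is not None).
def loading_desc(current_shard, num_shards):
    suffix = f" (shard {current_shard}/{num_shards})" if num_shards is not None else ""
    return "Loading model tensors" + suffix

print(loading_desc(0, None))   # Loading model tensors
print(loading_desc(2, 3))      # Loading model tensors (shard 2/3)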
@@ -1560,13 +1570,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
     else:
         from transformers import PreTrainedModel
+        from transformers import modeling_utils
         old_from_pretrained = PreTrainedModel.from_pretrained.__func__
         @classmethod
         def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+            utils.num_shards = None
+            utils.current_shard = 0
             if not args.no_aria2:
                 utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
             return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
         PreTrainedModel.from_pretrained = new_from_pretrained
+        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+            utils.num_shards = utils.get_num_shards(index_filename)
+            return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
     def tpumtjgetsofttokens():
         soft_tokens = None
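
This is the same pair of patches applied a second time on a different model-loading path (apparently the TPU Colab path, given the surrounding tpumtjgetsofttokens definition). One detail worth noting: resetting both counters inside `new_from_pretrained` means a later non-sharded load cannot inherit a stale shard count from an earlier sharded one. A toy sketch of that behaviour (all names here are illustrative, not project code):

# Toy sketch (hypothetical names): the reset guarantees each load starts
# from a clean slate.
class shard_state:
    num_shards = None
    current_shard = 0

def pretend_from_pretrained(is_sharded):
    shard_state.num_shards = None   # reset before every load
    shard_state.current_shard = 0
    if is_sharded:
        shard_state.num_shards = 3  # would normally come from the index file

pretend_from_pretrained(is_sharded=True)
print(shard_state.num_shards)       # 3
pretend_from_pretrained(is_sharded=False)
print(shard_state.num_shards)       # None -- no stale value carried over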

@@ -7,7 +7,7 @@ dm-haiku == 0.0.5
 jax == 0.2.21
 transformers >= 4.19
 progressbar2
-git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck-staging
+git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
 flask
 Flask-SocketIO
 flask-cloudflared >= 0.0.5

@@ -27,6 +27,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 '''
 
+import utils
 import multiprocessing
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 import progressbar
@@ -1163,8 +1165,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         last_storage_key = None
         f = None
         current_offset = 0
-        print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
-        for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+        if utils.current_shard == 0:
+            print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
+        if utils.num_shards is not None:
+            utils.current_shard += 1
+        for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
             # Some model weights are used by transformers but not by MTJ.
             # We have to materialize these weights anyways because
@@ -1225,6 +1230,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             np.empty(params["cores_per_replica"]),
         )
 
+        if utils.num_shards is not None and utils.current_shard < utils.num_shards:
+            return
+
         # Check for tensors that MTJ needs that were not provided in the
         # HF model
         for mk, mv in network.state["params"].items():
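
Taken together, these TPU-backend hunks make the HF-checkpoint callback safe to invoke once per shard: the parameter-count banner prints only on the first call, the shard counter advances on each call, and the post-load consistency checks are skipped with an early return until the final shard has been consumed. A rough standalone sketch of that control flow (illustrative only; `shard_state` and the function are hypothetical stand-ins):

# Illustrative sketch, not the project's code: a loader callback that can
# be invoked once per checkpoint shard.  `shard_state` mimics the
# num_shards/current_shard globals added to utils.py.
class shard_state:
    num_shards = None   # None => single-file (unsharded) checkpoint
    current_shard = 0

def load_checkpoint_part(model_dict):
    if shard_state.current_shard == 0:
        print("one-time setup, e.g. report the parameter count")
    if shard_state.num_shards is not None:
        shard_state.current_shard += 1
    for key in sorted(model_dict):
        pass  # materialize and copy this part's tensors to the TPU
    if shard_state.num_shards is not None and shard_state.current_shard < shard_state.num_shards:
        return  # more shards coming; defer the final consistency checks
    print("final consistency checks after the last shard")

# Simulate a three-shard load:
shard_state.num_shards = 3
for shard in ({"a": 1}, {"b": 2}, {"c": 3}):
    load_checkpoint_part(shard)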

@@ -6,8 +6,11 @@ import subprocess
 import tempfile
 import requests
 import os
+from typing import Optional
 
 vars = None
+num_shards: Optional[int] = None
+current_shard = 0
 
 #==================================================================#
 # Decorator to prevent a function's actions from being run until
@@ -133,7 +136,7 @@ def decodenewlines(txt):
     return txt
 
 #==================================================================#
-# Downloads sharded huggingface checkpoints using aria2c if possible
+# Downloads huggingface checkpoints using aria2c if possible
 #==================================================================#
 def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
     import transformers
@@ -225,3 +228,12 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
             os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
         with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
             json.dump({"url": u, "etag": t}, f)
+
+#==================================================================#
+# Given the path to a pytorch_model.bin.index.json, returns how many
+# shards there are in the model
+#==================================================================#
+def get_num_shards(filename):
+    with open(filename) as f:
+        map_data = json.load(f)
+    return len(set(map_data["weight_map"].values()))
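
A worked example of the new `get_num_shards` helper: the `weight_map` in a `pytorch_model.bin.index.json` maps every weight name to the shard file that contains it, so the number of distinct file names is the shard count. (The file and weight names below are made up for illustration.)

# Worked example (made-up names): counting distinct shard files in an
# index's weight_map, using the same logic as get_num_shards above.
import json, os, tempfile

index = {
    "metadata": {"total_size": 123},
    "weight_map": {
        "transformer.wte.weight": "pytorch_model-00001-of-00002.bin",
        "transformer.h.0.attn.weight": "pytorch_model-00001-of-00002.bin",
        "transformer.h.1.attn.weight": "pytorch_model-00002-of-00002.bin",
    },
}
path = os.path.join(tempfile.mkdtemp(), "pytorch_model.bin.index.json")
with open(path, "w") as f:
    json.dump(index, f)

with open(path) as f:
    map_data = json.load(f)
print(len(set(map_data["weight_map"].values())))   # -> 2 shards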