Allow TPU Colab to load sharded HF models
parent 4fa5f1cd6a
commit b1d8797a54

aiserver.py (20 changes)
@@ -1127,13 +1127,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     from transformers import __version__ as transformers_version
     from transformers import PreTrainedModel
+    from transformers import modeling_utils
     old_from_pretrained = PreTrainedModel.from_pretrained.__func__
     @classmethod
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        utils.num_shards = None
+        utils.current_shard = 0
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     PreTrainedModel.from_pretrained = new_from_pretrained
+    old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+    def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+        utils.num_shards = utils.get_num_shards(index_filename)
+        return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+    modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files

     # Lazy loader
     import torch_lazy_loader
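The hunk above uses a wrap-and-delegate monkeypatch: keep a reference to the original function, install a replacement that resets bookkeeping state, then call through unchanged. A standalone sketch of that pattern (the Loader class and all names below are invented for illustration, not KoboldAI code):

class Loader:
    @classmethod
    def load(cls, name, **kwargs):
        return f"loaded {name}"

num_shards = None   # module-level bookkeeping, like utils.num_shards
current_shard = 0

old_load = Loader.load.__func__  # plain function underneath the classmethod

@classmethod
def new_load(cls, name, **kwargs):
    global num_shards, current_shard
    num_shards = None    # reset before each load, as new_from_pretrained does
    current_shard = 0
    return old_load(cls, name, **kwargs)

Loader.load = new_load
print(Loader.load("my-model"))  # -> loaded my-model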
@@ -1170,7 +1178,9 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         last_storage_key = None
         f = None
         current_offset = 0
-        for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+        if utils.num_shards is not None:
+            utils.current_shard += 1
+        for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
             storage_key = model_dict[key].key
             if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
                 last_storage_key = storage_key
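The progress-bar label is built conditionally, so single-file models keep the old text while sharded loads show their position. In isolation, with invented values:

for num_shards, current_shard in [(None, 0), (3, 2)]:
    desc = "Loading model tensors" + (
        f" (shard {current_shard}/{num_shards})" if num_shards is not None else "")
    print(desc)
# Loading model tensors
# Loading model tensors (shard 2/3)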
@@ -1560,13 +1570,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
     else:
         from transformers import PreTrainedModel
+        from transformers import modeling_utils
         old_from_pretrained = PreTrainedModel.from_pretrained.__func__
         @classmethod
         def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+            utils.num_shards = None
+            utils.current_shard = 0
             if not args.no_aria2:
                 utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
             return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
         PreTrainedModel.from_pretrained = new_from_pretrained
+        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+            utils.num_shards = utils.get_num_shards(index_filename)
+            return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files

 def tpumtjgetsofttokens():
     soft_tokens = None
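This block repeats the CPU/GPU patch almost line for line. A hypothetical refactor, not part of this commit, could install both hooks from one helper; the aria2 call is omitted and `utils` is passed in as a parameter so the sketch stays self-contained:

from transformers import PreTrainedModel, modeling_utils

def install_sharding_hooks(utils):
    # `utils` is any object with num_shards/current_shard attributes
    # and a get_num_shards(index_filename) function.
    old_from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        utils.num_shards = None       # reset the shard bookkeeping for each load
        utils.current_shard = 0
        return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)

    PreTrainedModel.from_pretrained = new_from_pretrained

    old_shard_files = modeling_utils.get_checkpoint_shard_files

    def new_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
        utils.num_shards = utils.get_num_shards(index_filename)  # count shards from the index
        return old_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)

    modeling_utils.get_checkpoint_shard_files = new_shard_files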
requirements_mtj.txt

@@ -7,7 +7,7 @@ dm-haiku == 0.0.5
 jax == 0.2.21
 transformers >= 4.19
 progressbar2
-git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck-staging
+git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
 flask
 Flask-SocketIO
 flask-cloudflared >= 0.0.5
tpu_mtj_backend.py

@@ -27,6 +27,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 '''

+import utils
+
 import multiprocessing
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 import progressbar
@@ -1163,8 +1165,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         last_storage_key = None
         f = None
         current_offset = 0
-        print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
-        for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+        if utils.current_shard == 0:
+            print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
+        if utils.num_shards is not None:
+            utils.current_shard += 1
+        for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):

             # Some model weights are used by transformers but not by MTJ.
             # We have to materialize these weights anyways because
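On the TPU path the loader callback fires once per checkpoint file, hence the two guards: the parameter-count banner prints only on the first callback (current_shard is still 0 before the increment), and the counter advances only for sharded models. A toy trace with invented file names:

num_shards, current_shard = 3, 0
for filename in ["pytorch_model-00001-of-00003.bin",
                 "pytorch_model-00002-of-00003.bin",
                 "pytorch_model-00003-of-00003.bin"]:
    if current_shard == 0:
        print("This model has ... parameters.")   # printed once, on the first shard only
    if num_shards is not None:
        current_shard += 1
    print(f"Loading model tensors (shard {current_shard}/{num_shards})")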
@@ -1225,6 +1230,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             np.empty(params["cores_per_replica"]),
         )

+        if utils.num_shards is not None and utils.current_shard < utils.num_shards:
+            return
+
         # Check for tensors that MTJ needs that were not provided in the
         # HF model
         for mk, mv in network.state["params"].items():
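The added early return defers the one-time finalization work (checking for tensors MTJ needs that the HF checkpoint never provided) until the final shard's callback. The guard in isolation, as a hypothetical helper:

def ready_to_finalize(current_shard, num_shards):
    # More shards still pending: skip finalization for now.
    if num_shards is not None and current_shard < num_shards:
        return False
    return True   # single-file model, or the final shard just loaded

assert ready_to_finalize(2, 3) is False
assert ready_to_finalize(3, 3) is True
assert ready_to_finalize(0, None) is True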
utils.py (14 changes)

@@ -6,8 +6,11 @@ import subprocess
 import tempfile
 import requests
 import os
+from typing import Optional

 vars = None
+num_shards: Optional[int] = None
+current_shard = 0

 #==================================================================#
 # Decorator to prevent a function's actions from being run until
@@ -133,7 +136,7 @@ def decodenewlines(txt):
     return txt

 #==================================================================#
-# Downloads sharded huggingface checkpoints using aria2c if possible
+# Downloads huggingface checkpoints using aria2c if possible
 #==================================================================#
 def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
     import transformers
@@ -225,3 +228,12 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
         os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
         with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
             json.dump({"url": u, "etag": t}, f)
+
+#==================================================================#
+# Given the path to a pytorch_model.bin.index.json, returns how many
+# shards there are in the model
+#==================================================================#
+def get_num_shards(filename):
+    with open(filename) as f:
+        map_data = json.load(f)
+    return len(set(map_data["weight_map"].values()))
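get_num_shards counts distinct shard files in the index's weight_map rather than tensor names, since several tensors usually live in the same file. A worked example with a minimal, made-up index:

import json, os, tempfile

index = {
    "metadata": {"total_size": 123},
    "weight_map": {
        "transformer.wte.weight": "pytorch_model-00001-of-00002.bin",
        "transformer.ln_f.weight": "pytorch_model-00002-of-00002.bin",
        "transformer.ln_f.bias":   "pytorch_model-00002-of-00002.bin",
    },
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
    json.dump(index, tmp)

# Equivalent to utils.get_num_shards(tmp.name):
with open(tmp.name) as f:
    map_data = json.load(f)
print(len(set(map_data["weight_map"].values())))  # -> 2 (three tensors, two shard files)
os.remove(tmp.name)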