Allow TPU Colab to load sharded HF models

This commit is contained in:
Gnome Ann
2022-05-12 23:51:40 -04:00
parent 4fa5f1cd6a
commit b1d8797a54
4 changed files with 43 additions and 5 deletions

View File

@ -27,6 +27,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import utils
import multiprocessing
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
import progressbar
@ -1163,8 +1165,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
last_storage_key = None
f = None
current_offset = 0
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
if utils.current_shard == 0:
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
if utils.num_shards is not None:
utils.current_shard += 1
for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
# Some model weights are used by transformers but not by MTJ.
# We have to materialize these weights anyways because
@ -1225,6 +1230,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
np.empty(params["cores_per_replica"]),
)
if utils.num_shards is not None and utils.current_shard < utils.num_shards:
return
# Check for tensors that MTJ needs that were not provided in the
# HF model
for mk, mv in network.state["params"].items():