From 4fa5f1cd6afb3486704870c2e56e84f5888e7f71 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 12 May 2022 22:21:15 -0400
Subject: [PATCH] Add TPU support for OPT-350M

The 350M model has a different structure than the other OPT models: its
word embedding dimension (word_embed_proj_dim) differs from hidden_size,
so its checkpoint carries decoder.project_in/decoder.project_out weights
and the embedding dimension has to be read from a separate d_embed value.
---
 aiserver.py        | 8 ++++----
 maps/opt.json      | 5 ++++-
 tpu_mtj_backend.py | 2 +-
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 0dc19c5a..1f105701 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -772,7 +772,7 @@ def spRequest(filename):
         tensor = tensor.reshape(
             tpu_mtj_backend.params["cores_per_replica"],
             -1,
-            tpu_mtj_backend.params["d_model"],
+            tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
         )
         vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
     else:
@@ -1574,14 +1574,14 @@ else:
                 global np
                 if 'np' not in globals():
                     import numpy as np
-                tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
+                tensor = np.zeros((1, tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])), dtype=np.float32)
                 rows = tensor.shape[0]
                 padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
                 tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
                 tensor = tensor.reshape(
                     tpu_mtj_backend.params["cores_per_replica"],
                     -1,
-                    tpu_mtj_backend.params["d_model"],
+                    tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
                 )
                 vars.sp = tpu_mtj_backend.shard_xmap(tensor)
                 soft_tokens = np.arange(
@@ -1682,7 +1682,7 @@ else:
             loadmodelsettings()
             loadsettings()
             tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
-            vars.modeldim = int(tpu_mtj_backend.params["d_model"])
+            vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]))
             tokenizer = tpu_mtj_backend.tokenizer
         else:
             loadsettings()
diff --git a/maps/opt.json b/maps/opt.json
index b9667ac9..c99ae19f 100644
--- a/maps/opt.json
+++ b/maps/opt.json
@@ -3,13 +3,16 @@
     "mtj_pe": "fixed",
     "mtj_config_map": {
         "do_layer_norm_before": ["do_layer_norm_before", true],
+        "d_embed": "word_embed_proj_dim",
         "d_model": "hidden_size",
         "n_heads": "num_attention_heads",
         "layers": "num_hidden_layers"
     },
     "static_weights": {
         "decoder.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
-        "decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}}
+        "decoder.project_in.weight": {"mtj": {"module": "embedding_shard", "param": "project_in"}},
+        "decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}},
+        "decoder.project_out.weight": {"mtj": {"module": "projection_shard", "param": "project_out"}}
     },
     "layer_weights": {
         "decoder.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 75b4ee9c..b956648b 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1054,7 +1054,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         # by the number of TPU cores, and fall back to one core if an even
         # number of TPU cores is not possible.
         for c in (8, 6, 4, 2, 1):
-            if 0 == params["n_heads"] % c == params["d_model"] % c:
+            if 0 == params["n_heads"] % c == params.get("d_embed", params["d_model"]) % c:
                 params["cores_per_replica"] = c
                 break
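
Below is a minimal sketch (not part of the patch itself) of the fallback this
change relies on. The helper names and config values are illustrative
assumptions, meant only to show how params.get("d_embed", params["d_model"])
resolves for OPT-350M versus the other OPT sizes, and how the core-count loop
above uses that dimension.

    # Hypothetical standalone sketch of the d_embed fallback; the dicts are
    # illustrative of OPT configs, not read from an actual checkpoint.
    def embedding_dim(params: dict) -> int:
        # Prefer the separate word embedding dimension when the model
        # defines one; otherwise fall back to the hidden size, as before.
        return params.get("d_embed", params["d_model"])

    def pick_cores_per_replica(params: dict) -> int:
        # Mirrors the loop changed above: take the largest core count that
        # divides both the attention head count and the embedding dimension.
        for c in (8, 6, 4, 2, 1):
            if 0 == params["n_heads"] % c == embedding_dim(params) % c:
                return c
        return 1

    opt_350m = {"d_model": 1024, "d_embed": 512, "n_heads": 16}  # has project_in/out
    opt_1_3b = {"d_model": 2048, "n_heads": 32}                  # no separate d_embed

    print(embedding_dim(opt_350m), pick_cores_per_replica(opt_350m))  # 512 8
    print(embedding_dim(opt_1_3b), pick_cores_per_replica(opt_1_3b))  # 2048 8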