Add TPU support for OPT-350M

The 350M model seems to have a different structure than the other ones ???
This commit is contained in:
Gnome Ann 2022-05-12 22:21:15 -04:00
parent dfa2aa7314
commit 4fa5f1cd6a
3 changed files with 9 additions and 6 deletions

View File

@ -772,7 +772,7 @@ def spRequest(filename):
tensor = tensor.reshape(
tpu_mtj_backend.params["cores_per_replica"],
-1,
tpu_mtj_backend.params["d_model"],
tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
)
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
else:
@ -1574,14 +1574,14 @@ else:
global np
if 'np' not in globals():
import numpy as np
tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
tensor = np.zeros((1, tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])), dtype=np.float32)
rows = tensor.shape[0]
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
tensor = tensor.reshape(
tpu_mtj_backend.params["cores_per_replica"],
-1,
tpu_mtj_backend.params["d_model"],
tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
)
vars.sp = tpu_mtj_backend.shard_xmap(tensor)
soft_tokens = np.arange(
@ -1682,7 +1682,7 @@ else:
loadmodelsettings()
loadsettings()
tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]))
tokenizer = tpu_mtj_backend.tokenizer
else:
loadsettings()

View File

@ -3,13 +3,16 @@
"mtj_pe": "fixed",
"mtj_config_map": {
"do_layer_norm_before": ["do_layer_norm_before", true],
"d_embed": "word_embed_proj_dim",
"d_model": "hidden_size",
"n_heads": "num_attention_heads",
"layers": "num_hidden_layers"
},
"static_weights": {
"decoder.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
"decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}}
"decoder.project_in.weight": {"mtj": {"module": "embedding_shard", "param": "project_in"}},
"decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}},
"decoder.project_out.weight": {"mtj": {"module": "projection_shard", "param": "project_out"}}
},
"layer_weights": {
"decoder.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},

View File

@ -1054,7 +1054,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
# by the number of TPU cores, and fall back to one core if an even
# number of TPU cores is not possible.
for c in (8, 6, 4, 2, 1):
if 0 == params["n_heads"] % c == params["d_model"] % c:
if 0 == params["n_heads"] % c == params.get("d_embed", params["d_model"]) % c:
params["cores_per_replica"] = c
break