Add TPU support for OPT-350M
The 350M model seems to have a different structure than the other OPT models: its word embedding width (word_embed_proj_dim, mapped here to d_embed) differs from its hidden size, and the checkpoint carries extra project_in/project_out weights, so the embedding dimension has to be threaded through the TPU backend separately.
parent dfa2aa7314
commit 4fa5f1cd6a
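For context: OPT-350M appears to be the one OPT checkpoint whose word embedding width (word_embed_proj_dim) differs from its hidden size, which is why a separate d_embed value is needed at all. A minimal sketch, assuming the Hugging Face transformers package and the facebook/opt-350m checkpoint, that shows the mismatch:

# Sketch: compare the two widths this diff distinguishes as
# d_embed (word_embed_proj_dim) and d_model (hidden_size).
# Assumes transformers is installed and can reach the Hub.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/opt-350m")
print(config.hidden_size)          # d_model (1024 for OPT-350M)
print(config.word_embed_proj_dim)  # d_embed (512, smaller than hidden_size)
# The other OPT sizes set word_embed_proj_dim == hidden_size, so the
# d_embed -> d_model fallback used throughout this diff is a no-op for them.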
@@ -772,7 +772,7 @@ def spRequest(filename):
         tensor = tensor.reshape(
             tpu_mtj_backend.params["cores_per_replica"],
             -1,
-            tpu_mtj_backend.params["d_model"],
+            tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
         )
         vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
     else:
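The .get("d_embed", ...) pattern above uses the embedding width when the model defines one and falls back to d_model otherwise; the soft-prompt tensor is then reshaped so each TPU core receives an equal slice. A minimal sketch of that reshape with illustrative numbers, not taken from a real checkpoint:

# Illustrative numbers only: 8 TPU cores, d_embed 512, soft prompt padded to 32 rows.
import numpy as np

params = {"cores_per_replica": 8, "d_model": 1024, "d_embed": 512}
d_embed = params.get("d_embed", params["d_model"])  # 512 here; d_model when no d_embed

tensor = np.zeros((32, d_embed), dtype=np.float32)
tensor = tensor.reshape(params["cores_per_replica"], -1, d_embed)
print(tensor.shape)  # (8, 4, 512): four soft-prompt rows per core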
@@ -1574,14 +1574,14 @@ else:
                 global np
                 if 'np' not in globals():
                     import numpy as np
-                tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
+                tensor = np.zeros((1, tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])), dtype=np.float32)
                 rows = tensor.shape[0]
                 padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
                 tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
                 tensor = tensor.reshape(
                     tpu_mtj_backend.params["cores_per_replica"],
                     -1,
-                    tpu_mtj_backend.params["d_model"],
+                    tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
                 )
                 vars.sp = tpu_mtj_backend.shard_xmap(tensor)
                 soft_tokens = np.arange(
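The padding line above relies on Python's floored modulo with a negative divisor: seq - (seq % -cores_per_replica) rounds seq up to the next multiple of cores_per_replica, and subtracting rows gives the number of zero rows to append. A worked sketch with illustrative numbers:

# Illustrative numbers, not a real config.
seq, cores_per_replica, rows = 2048, 6, 1

rounded_up = seq - (seq % -cores_per_replica)  # 2048 % -6 == -4, so this is 2052
padding_amount = rounded_up - rows             # 2051 zero rows get appended
assert rounded_up % cores_per_replica == 0
print(padding_amount)                          # 2051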
@@ -1682,7 +1682,7 @@ else:
         loadmodelsettings()
         loadsettings()
         tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
-        vars.modeldim = int(tpu_mtj_backend.params["d_model"])
+        vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]))
         tokenizer = tpu_mtj_backend.tokenizer
     else:
         loadsettings()
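vars.modeldim is the width that soft-prompt tensors are checked against, so for OPT-350M it has to be the embedding width rather than the hidden size. A small sketch of that invariant; params and soft_prompt here are illustrative stand-ins, not the application's real state:

# Illustrative stand-ins for the real params dict and an uploaded soft prompt.
import numpy as np

params = {"d_model": 1024, "d_embed": 512}
modeldim = int(params.get("d_embed", params["d_model"]))  # 512 for an OPT-350M-like config

soft_prompt = np.zeros((20, 512), dtype=np.float32)
assert soft_prompt.shape[-1] == modeldim, "soft prompt width must match d_embed"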
@@ -3,13 +3,16 @@
     "mtj_pe": "fixed",
     "mtj_config_map": {
         "do_layer_norm_before": ["do_layer_norm_before", true],
+        "d_embed": "word_embed_proj_dim",
         "d_model": "hidden_size",
         "n_heads": "num_attention_heads",
         "layers": "num_hidden_layers"
     },
     "static_weights": {
         "decoder.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
-        "decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}}
+        "decoder.project_in.weight": {"mtj": {"module": "embedding_shard", "param": "project_in"}},
+        "decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}},
+        "decoder.project_out.weight": {"mtj": {"module": "projection_shard", "param": "project_out"}}
     },
     "layer_weights": {
         "decoder.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
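The mtj_config_map above maps MTJ parameter names to keys of the Hugging Face model config; the two-element form appears to pair a config key with a default used when that key is absent. A hedged sketch of how such a map could be applied, where apply_config_map is an illustration only, not the backend's real loader:

# apply_config_map illustrates the "hf_key" vs ["hf_key", default] convention;
# the real mapping logic lives in tpu_mtj_backend.
def apply_config_map(params, hf_config, config_map):
    for mtj_key, spec in config_map.items():
        if isinstance(spec, list):            # ["hf_key", default]
            hf_key, default = spec
            params[mtj_key] = hf_config.get(hf_key, default)
        else:                                 # plain "hf_key"
            params[mtj_key] = hf_config[spec]
    return params

# OPT-350M-like values for illustration.
hf_config = {"hidden_size": 1024, "word_embed_proj_dim": 512,
             "num_attention_heads": 16, "num_hidden_layers": 24}
config_map = {"do_layer_norm_before": ["do_layer_norm_before", True],
              "d_embed": "word_embed_proj_dim", "d_model": "hidden_size",
              "n_heads": "num_attention_heads", "layers": "num_hidden_layers"}
print(apply_config_map({}, hf_config, config_map))
# {'do_layer_norm_before': True, 'd_embed': 512, 'd_model': 1024, 'n_heads': 16, 'layers': 24}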
@@ -1054,7 +1054,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     # by the number of TPU cores, and fall back to one core if an even
     # number of TPU cores is not possible.
     for c in (8, 6, 4, 2, 1):
-        if 0 == params["n_heads"] % c == params["d_model"] % c:
+        if 0 == params["n_heads"] % c == params.get("d_embed", params["d_model"]) % c:
             params["cores_per_replica"] = c
             break
 
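This loop picks the largest core count that evenly divides both the head count and, with this change, the embedding width when one is defined. A runnable sketch with illustrative OPT-350M-like values (16 heads, d_model 1024, d_embed 512):

# Illustrative OPT-350M-like values; not a real params object.
params = {"n_heads": 16, "d_model": 1024, "d_embed": 512}

for c in (8, 6, 4, 2, 1):
    if 0 == params["n_heads"] % c == params.get("d_embed", params["d_model"]) % c:
        params["cores_per_replica"] = c
        break

print(params["cores_per_replica"])  # 8: both 16 heads and 512 embedding dims split evenly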