From f5e689a725eb3ad450e80aa4a05d9890b84b5630 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 12 May 2022 19:09:31 -0400
Subject: [PATCH] Upload maps/opt.json and update requirements

---
 maps/opt.json        | 32 ++++++++++++++++++++++++++++++++
 requirements.txt     |  2 +-
 requirements_mtj.txt |  4 ++--
 tpu_mtj_backend.py   |  2 ++
 4 files changed, 37 insertions(+), 3 deletions(-)
 create mode 100644 maps/opt.json

diff --git a/maps/opt.json b/maps/opt.json
new file mode 100644
index 00000000..b9667ac9
--- /dev/null
+++ b/maps/opt.json
@@ -0,0 +1,32 @@
+{
+    "mtj_compat": "opt",
+    "mtj_pe": "fixed",
+    "mtj_config_map": {
+        "do_layer_norm_before": ["do_layer_norm_before", true],
+        "d_model": "hidden_size",
+        "n_heads": "num_attention_heads",
+        "layers": "num_hidden_layers"
+    },
+    "static_weights": {
+        "decoder.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
+        "decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}}
+    },
+    "layer_weights": {
+        "decoder.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
+        "decoder.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b"}},
+        "decoder.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
+        "decoder.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b"}},
+        "decoder.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
+        "decoder.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b"}},
+        "decoder.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
+        "decoder.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
+        "decoder.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
+        "decoder.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
+        "decoder.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
+        "decoder.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
+        "decoder.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
+        "decoder.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}},
+        "decoder.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}},
+        "decoder.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}}
+    }
+}
diff --git a/requirements.txt b/requirements.txt
index 897f9e8e..7b5b967c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-transformers>=4.17
+transformers>=4.19
 Flask
 Flask-SocketIO
 requests
diff --git a/requirements_mtj.txt b/requirements_mtj.txt
index 416a06a4..e2a6c4e1 100644
--- a/requirements_mtj.txt
+++ b/requirements_mtj.txt
@@ -5,9 +5,9 @@ requests
 optax >= 0.0.5, <= 0.0.9
 dm-haiku == 0.0.5
 jax == 0.2.21
-transformers >= 4.17
+transformers >= 4.19
 progressbar2
-git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
+git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck-staging
 flask
 Flask-SocketIO
 flask-cloudflared >= 0.0.5
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 1f67763f..75b4ee9c 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1200,6 +1200,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
 
                     # MTJ requires certain mathematical operations to be performed
                     # on tensors in order for them to be in the correct format
+                    if "remove_first_two_rows" in transforms:
+                        tensor = tensor[2:]
                     if "divide_by_shards" in transforms:
                         tensor /= params["cores_per_replica"]
                     if "vocab_pad" in transforms:
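
A minimal standalone sketch of how the transform names declared in maps/opt.json relate to the tensor handling added in tpu_mtj_backend.py above. The helper name apply_transforms, its arguments, and the padding target are illustrative assumptions for this note, not code from the patch; only the remove_first_two_rows and divide_by_shards branches mirror lines that actually appear in the hunk. In Hugging Face's OPT checkpoints the learned position-embedding table carries a two-row offset, which is what the remove_first_two_rows transform accounts for before handing pos_embs to MTJ.

import numpy as np

def apply_transforms(tensor, transforms, cores_per_replica=8, padded_rows=None):
    # Illustrative helper (hypothetical name): applies the map's transform
    # names to a checkpoint tensor the same way the new hunk above does.
    if "remove_first_two_rows" in transforms:
        # Drop the 2-row offset that HF's OPT position-embedding table carries.
        tensor = tensor[2:]
    if "divide_by_shards" in transforms:
        # Pre-divide biases that are later summed across TPU shards.
        tensor = tensor / cores_per_replica
    if "vocab_pad" in transforms and padded_rows is not None:
        # Assumption: pad the embedding rows up to a sharding-friendly size.
        tensor = np.pad(tensor, ((0, padded_rows - tensor.shape[0]), (0, 0)))
    return tensor

# OPT-125M stores a (2048 + 2) x 768 position table; after the transforms
# listed for decoder.embed_positions.weight it becomes 2048 x 768.
pos_embs = np.zeros((2050, 768), dtype=np.float32)
print(apply_transforms(pos_embs, ["no_transpose", "remove_first_two_rows"]).shape)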