From 05fc46b2536ce47ec6c18092aba600a046886a56 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sat, 19 Mar 2022 02:09:41 -0400 Subject: [PATCH] Changing this again to divide by 8 --- tpu_mtj_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 364d39d5..31f0485d 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -885,7 +885,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2): original_shape = shards[0][key].shape for checkpoint_shard in range(checkpoint_shards): if key in ("attention.dense.bias", "mlp.dense_4h_to_h.bias"): - shards[checkpoint_shard][key] /= output_shards + shards[checkpoint_shard][key] /= config["cores_per_replica"] if key != "word_embeddings.weight" and shards[checkpoint_shard][key].ndim == 2: shards[checkpoint_shard][key] = shards[checkpoint_shard][key].T tensor = shards[checkpoint_shard][key]