diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 31f0485d..364d39d5 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -885,7 +885,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2): original_shape = shards[0][key].shape for checkpoint_shard in range(checkpoint_shards): if key in ("attention.dense.bias", "mlp.dense_4h_to_h.bias"): - shards[checkpoint_shard][key] /= config["cores_per_replica"] + shards[checkpoint_shard][key] /= output_shards if key != "word_embeddings.weight" and shards[checkpoint_shard][key].ndim == 2: shards[checkpoint_shard][key] = shards[checkpoint_shard][key].T tensor = shards[checkpoint_shard][key]