Should divide NeoX replicated parameters by 8 (not by 4)

Also, suppresses the PyTorch 1.11 warning about transposing tensors with
ndim != 2 in the new code
This commit is contained in:
Gnome Ann 2022-03-19 00:48:33 -04:00
parent c2c139e940
commit f16b61ec77
1 changed files with 2 additions and 2 deletions

View File

@ -885,8 +885,8 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
original_shape = shards[0][key].shape
for checkpoint_shard in range(checkpoint_shards):
if key in ("attention.dense.bias", "mlp.dense_4h_to_h.bias"):
shards[checkpoint_shard][key] /= output_shards
if key != "word_embeddings.weight":
shards[checkpoint_shard][key] /= config["cores_per_replica"]
if key != "word_embeddings.weight" and shards[checkpoint_shard][key].ndim == 2:
shards[checkpoint_shard][key] = shards[checkpoint_shard][key].T
tensor = shards[checkpoint_shard][key]
if target_axis is not None: