xmaps for moving things onto TPU

This commit is contained in:
Gnome Ann
2022-01-12 21:45:30 -05:00
parent 49e2bcab1a
commit a3d6dc93e8
2 changed files with 41 additions and 11 deletions

View File

@ -993,7 +993,7 @@ else:
-1,
tpu_mtj_backend.params["d_model"],
)
vars.sp = tensor
vars.sp = tpu_mtj_backend.shard_xmap(tensor)
soft_tokens = np.arange(
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
@ -4005,7 +4005,7 @@ def spRequest(filename):
-1,
tpu_mtj_backend.params["d_model"],
)
vars.sp = np.float32(tensor)
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
else:
vars.sp = torch.from_numpy(tensor)