diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 03f4cdde..a565b578 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -38,6 +38,7 @@ import zipfile
 import requests
 import random
 import jax
+import jax.dlpack
 from jax.config import config
 from jax.experimental import maps
 import jax.numpy as jnp
@@ -990,18 +991,19 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 if "no_transpose" not in transforms:
                     tensor = tensor.T
                 tensor.unsqueeze_(0)
+                if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
+                    tensor = tensor.bfloat16()
 
                 # Shard the tensor so that parts of the tensor can be used
                 # on different TPU cores
                 network.state["params"][spec["module"]][spec["param"]] = move_xmap(
-                    jnp.array(
+                    jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(
                         reshard_reverse(
                             tensor,
                             params["cores_per_replica"],
                             network.state["params"][spec["module"]][spec["param"]].shape,
-                        ),
-                        dtype=jnp.bfloat16,
-                    ),
+                        )
+                    )).copy(),
                     np.empty(params["cores_per_replica"]),
                 )
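
For context, the sketch below illustrates the conversion path the patch switches to: the checkpoint weight is cast to bfloat16 on the PyTorch side, then handed to JAX through DLPack rather than being rebuilt with jnp.array. This is not part of the patch; the helper name torch_to_jax_bf16 is hypothetical, and it assumes torch and jax builds that both support bfloat16 over DLPack.

import torch
import torch.utils.dlpack
import jax
import jax.dlpack
import jax.numpy as jnp

def torch_to_jax_bf16(tensor: torch.Tensor) -> jnp.ndarray:
    # Cast half/full-precision weights to bfloat16 before the hand-off,
    # mirroring the dtype check added in the patched loader.
    if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
        tensor = tensor.bfloat16()
    # DLPack shares the underlying buffer instead of serializing the data,
    # so the transfer itself is zero-copy; the trailing .copy() gives JAX
    # its own buffer so the PyTorch tensor can be freed afterwards.
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)).copy()

# Example usage:
weights = torch.randn(2, 3, dtype=torch.float32)
print(torch_to_jax_bf16(weights).dtype)  # bfloat16

Compared with jnp.array(tensor, dtype=jnp.bfloat16), this avoids materializing an intermediate float32 copy of each shard during loading, which is the apparent motivation for the change.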