From a99eb8724d2c5e9c1f9ba9a3cb5a6bcc1a1ec066 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 10 Mar 2022 15:12:42 -0500
Subject: [PATCH] Use DLPack to convert PyTorch tensors to JAX arrays

---
 tpu_mtj_backend.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 03f4cdde..a565b578 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -38,6 +38,7 @@ import zipfile
 import requests
 import random
 import jax
+import jax.dlpack
 from jax.config import config
 from jax.experimental import maps
 import jax.numpy as jnp
@@ -990,18 +991,19 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                         if "no_transpose" not in transforms:
                             tensor = tensor.T
                         tensor.unsqueeze_(0)
+                        if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
+                            tensor = tensor.bfloat16()

                         # Shard the tensor so that parts of the tensor can be used
                         # on different TPU cores
                         network.state["params"][spec["module"]][spec["param"]] = move_xmap(
-                            jnp.array(
+                            jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(
                                 reshard_reverse(
                                     tensor,
                                     params["cores_per_replica"],
                                     network.state["params"][spec["module"]][spec["param"]].shape,
-                                ),
-                                dtype=jnp.bfloat16,
-                            ),
+                                )
+                            )).copy(),
                             np.empty(params["cores_per_replica"]),
                         )

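
Note (not part of the patch): below is a minimal, self-contained sketch of the DLPack
hand-off this patch relies on, i.e. exporting a PyTorch tensor as a DLPack capsule and
importing it into JAX without a NumPy round trip. The example tensor, its shape, and
the expected output are illustrative assumptions, not taken from tpu_mtj_backend.py.

    import torch
    import torch.utils.dlpack
    import jax
    import jax.dlpack

    # Cast to bfloat16 on the PyTorch side first; DLPack transfers the dtype
    # as-is, and NumPy has no native bfloat16 to convert through.
    tensor = torch.arange(8, dtype=torch.float32).reshape(2, 4).bfloat16()

    # Export the tensor as a DLPack capsule and import it into JAX.
    capsule = torch.utils.dlpack.to_dlpack(tensor)
    array = jax.dlpack.from_dlpack(capsule)

    # .copy() gives JAX its own buffer so the result no longer aliases the
    # original PyTorch storage (mirroring the .copy() call in the patch).
    array = array.copy()
    print(array.dtype, array.shape)  # expected: bfloat16 (2, 4)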