Use DLPack to convert PyTorch tensors to JAX arrays

This commit is contained in:
Gnome Ann 2022-03-10 15:12:42 -05:00
parent 68281184bf
commit a99eb8724d
1 changed files with 6 additions and 4 deletions

View File

@ -38,6 +38,7 @@ import zipfile
import requests
import random
import jax
import jax.dlpack
from jax.config import config
from jax.experimental import maps
import jax.numpy as jnp
@ -990,18 +991,19 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if "no_transpose" not in transforms:
tensor = tensor.T
tensor.unsqueeze_(0)
if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
tensor = tensor.bfloat16()
# Shard the tensor so that parts of the tensor can be used
# on different TPU cores
network.state["params"][spec["module"]][spec["param"]] = move_xmap(
jnp.array(
jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(
reshard_reverse(
tensor,
params["cores_per_replica"],
network.state["params"][spec["module"]][spec["param"]].shape,
),
dtype=jnp.bfloat16,
),
)
)).copy(),
np.empty(params["cores_per_replica"]),
)