Use DLPack to convert PyTorch tensors to JAX arrays

Gnome Ann 2022-03-10 15:12:42 -05:00
parent 68281184bf
commit a99eb8724d


@@ -38,6 +38,7 @@ import zipfile
 import requests
 import random
 import jax
+import jax.dlpack
 from jax.config import config
 from jax.experimental import maps
 import jax.numpy as jnp
@@ -990,18 +991,19 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         if "no_transpose" not in transforms:
             tensor = tensor.T
         tensor.unsqueeze_(0)
+        if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
+            tensor = tensor.bfloat16()
         # Shard the tensor so that parts of the tensor can be used
         # on different TPU cores
         network.state["params"][spec["module"]][spec["param"]] = move_xmap(
-            jnp.array(
+            jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(
                 reshard_reverse(
                     tensor,
                     params["cores_per_replica"],
                     network.state["params"][spec["module"]][spec["param"]].shape,
-                ),
-                dtype=jnp.bfloat16,
-            ),
+                )
+            )).copy(),
             np.empty(params["cores_per_replica"]),
         )
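
Below is a minimal standalone sketch of the technique this commit applies, separated from the model-loading code; the helper name torch_to_jax and the sample tensor are illustrative, not part of the commit:

import torch
import torch.utils.dlpack
import jax.dlpack

def torch_to_jax(tensor):
    # As in the commit, cast float16/float32 weights to bfloat16 first,
    # matching the dtype=jnp.bfloat16 that the removed jnp.array() call
    # produced.
    if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
        tensor = tensor.bfloat16()
    # to_dlpack() wraps the tensor's existing buffer in a DLPack capsule,
    # and from_dlpack() builds a JAX array around that buffer without an
    # element-by-element conversion.
    capsule = torch.utils.dlpack.to_dlpack(tensor)
    array = jax.dlpack.from_dlpack(capsule)
    # .copy() mirrors the commit: it gives JAX its own buffer, so the
    # result no longer aliases PyTorch-owned memory once the original
    # tensor is freed or mutated.
    return array.copy()

jax_array = torch_to_jax(torch.randn(1, 8, 4))  # hypothetical weight shard

The explicit .copy() still pays for one buffer copy, but the DLPack round trip avoids the general-purpose conversion path that jnp.array() takes for a PyTorch tensor, which is presumably why the commit makes the switch.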