Merge pull request #96 from VE-FORBRYDERNE/dlpack

Use DLPack to convert PyTorch tensors to JAX arrays
This commit is contained in:
henk717 2022-03-10 22:00:38 +01:00 committed by GitHub
commit 2c66461c14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 4 deletions

View File

@ -38,6 +38,7 @@ import zipfile
import requests
import random
import jax
import jax.dlpack
from jax.config import config
from jax.experimental import maps
import jax.numpy as jnp
@ -990,18 +991,19 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if "no_transpose" not in transforms:
tensor = tensor.T
tensor.unsqueeze_(0)
if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
tensor = tensor.bfloat16()
# Shard the tensor so that parts of the tensor can be used
# on different TPU cores
network.state["params"][spec["module"]][spec["param"]] = move_xmap(
jnp.array(
jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(
reshard_reverse(
tensor,
params["cores_per_replica"],
network.state["params"][spec["module"]][spec["param"]].shape,
),
dtype=jnp.bfloat16,
),
)
)).copy(),
np.empty(params["cores_per_replica"]),
)