Merge pull request #96 from VE-FORBRYDERNE/dlpack

Use DLPack to convert PyTorch tensors to JAX arrays
henk717 2022-03-10 22:00:38 +01:00 committed by GitHub
commit 2c66461c14
1 changed file with 6 additions and 4 deletions


@@ -38,6 +38,7 @@ import zipfile
 import requests
 import random
 import jax
+import jax.dlpack
 from jax.config import config
 from jax.experimental import maps
 import jax.numpy as jnp
@@ -990,18 +991,19 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         if "no_transpose" not in transforms:
             tensor = tensor.T
         tensor.unsqueeze_(0)
+        if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
+            tensor = tensor.bfloat16()
         # Shard the tensor so that parts of the tensor can be used
         # on different TPU cores
         network.state["params"][spec["module"]][spec["param"]] = move_xmap(
-            jnp.array(
+            jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(
                 reshard_reverse(
                     tensor,
                     params["cores_per_replica"],
                     network.state["params"][spec["module"]][spec["param"]].shape,
-                ),
-                dtype=jnp.bfloat16,
-            ),
+                )
+            )).copy(),
             np.empty(params["cores_per_replica"]),
         )
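
For readers unfamiliar with DLPack, the conversion pattern this commit adopts can be sketched in isolation. This is a minimal illustration, not part of the commit; the tensor shape and values are made up, and it assumes torch and jax versions whose DLPack support covers bfloat16 (which the change itself requires):

import jax.dlpack
import jax.numpy as jnp
import torch
import torch.utils.dlpack

# Stand-in for a resharded checkpoint tensor; the commit casts
# float16/float32 weights to bfloat16 before converting.
tensor = torch.ones((1, 4, 4), dtype=torch.bfloat16)

# to_dlpack wraps the tensor's buffer in a DLPack capsule without
# copying it; from_dlpack hands that buffer to JAX directly,
# replacing the jnp.array(...) call that converted via host memory.
array = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor))

# .copy() gives JAX its own buffer, as the diff does before sharding,
# so the result no longer aliases PyTorch-owned memory.
array = array.copy()

assert array.dtype == jnp.bfloat16
assert array.shape == (1, 4, 4)

The trailing .copy() is about ownership: from_dlpack aliases the PyTorch buffer, and copying immediately afterwards lets the original tensor be freed or mutated without invalidating the JAX array.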