Fix: TPU driver error

to_dlpack/from_dlpack was causing tensor conversion issues with the new jax version.
Bogdan Drema 2023-04-23 00:49:42 +01:00
parent e4c15fe1f6
commit 92a0bf9524
2 changed files with 20 additions and 13 deletions
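
For context, the core of the change is how each loaded torch tensor is handed to JAX. Below is a minimal CPU-only sketch (illustrative only, not the repository's code; the tensor `t` and its shape are made up) contrasting the DLPack round trip that the commit message says broke on the newer jax release with the plain-copy path the diff switches to:

    import numpy as np
    import torch
    from torch.utils.dlpack import to_dlpack
    import jax.dlpack
    import jax.numpy as jnp

    t = torch.ones((2, 4), dtype=torch.float32)

    # Old path: zero-copy DLPack handoff from torch to JAX.
    # Per the commit message, this started erroring with the newer jax version.
    old = jax.dlpack.from_dlpack(to_dlpack(t)).copy()

    # New path used in the diff below: detach from autograd, then let
    # jnp.array build a fresh JAX array from the tensor's buffer.
    new = jnp.array(t.detach())

    assert np.allclose(np.asarray(old), np.asarray(new))

The copy-based path gives up DLPack's zero-copy handoff but no longer depends on torch/jax DLPack interop, which is what was failing here.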

Changed file 1 of 2 (Python dependency pins):

@@ -2,9 +2,10 @@ torch >= 1.9, < 1.13
 numpy
 tqdm
 requests
-dm-haiku == 0.0.5
-jax == 0.2.21
-jaxlib >= 0.1.69, <= 0.3.7
+dm-haiku==0.0.9
+jax==0.3.25
+jaxlib==0.3.25
+transformers == 4.28.0
 chex == 0.1.5
 transformers == 4.24.0
 huggingface_hub==0.12.1
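
Note that jax and jaxlib are now pinned to the same release (0.3.25), with dm-haiku bumped alongside them. As a quick sanity check (illustrative only, not part of the commit), the installed versions can be compared against the new pins before loading a model:

    import importlib.metadata as md

    # New pins from the requirements diff above (jax and jaxlib should match).
    expected = {"jax": "0.3.25", "jaxlib": "0.3.25", "dm-haiku": "0.0.9"}
    for pkg, want in expected.items():
        have = md.version(pkg)
        print(f"{pkg}: installed {have}, pinned {want}, {'OK' if have == want else 'MISMATCH'}")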

Changed file 2 of 2 (TPU backend, model-loading code):

@@ -1049,7 +1049,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
         raise RuntimeError(error)
-def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
+def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
     if "pad_token_id" in kwargs:
@@ -1195,6 +1195,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
+    global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()
     batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
@@ -1244,6 +1245,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
         from tqdm.auto import tqdm
+        import functools
         def callback(model_dict, f, **_):
             if callback.nested:
                 return
@@ -1317,19 +1319,20 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
                     #if "no_transpose" not in transforms and tensor.ndim == 2:
                     #    tensor = tensor.T
                     tensor.unsqueeze_(0)
-                    if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
-                        tensor = tensor.bfloat16()
                     # Shard the tensor so that parts of the tensor can be used
                     # on different TPU cores
+                    tensor = reshard_reverse(
+                        tensor,
+                        params["cores_per_replica"],
+                        network.state["params"][spec["module"]][spec["param"]].shape,
+                    )
+                    tensor = jnp.array(tensor.detach())
+                    if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
+                        tensor = tensor.bfloat16()
                     network.state["params"][spec["module"]][spec["param"]] = move_xmap(
-                        jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(
-                            reshard_reverse(
-                                tensor,
-                                params["cores_per_replica"],
-                                network.state["params"][spec["module"]][spec["param"]].shape,
-                            )
-                        )).copy(),
+                        tensor,
                         np.empty(params["cores_per_replica"]),
                     )
@@ -1416,3 +1419,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
         model = GPTNeoForCausalLM.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
     #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+    global shard_xmap, batch_xmap
+    shard_xmap = __shard_xmap()
+    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
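
Both xmap-related hunks above hinge on the added `global shard_xmap, batch_xmap` statement: in Python, assigning to a name inside a function creates a local unless the name is declared global, so the declaration is what lets load_model rebind the module-level xmaps. A tiny self-contained illustration (toy values, not repository code):

    shard_xmap = None

    def rebind_without_global():
        shard_xmap = "local only"  # creates a new local; the module-level name is untouched

    def rebind_with_global():
        global shard_xmap          # as added in the diff above
        shard_xmap = "rebound"     # now rebinds the module-level name

    rebind_without_global()
    assert shard_xmap is None
    rebind_with_global()
    assert shard_xmap == "rebound"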