From 92a0bf9524f831d10b8b1006ee383a300fef5af8 Mon Sep 17 00:00:00 2001
From: Bogdan Drema
Date: Sun, 23 Apr 2023 00:49:42 +0100
Subject: [PATCH] Fix: TPU driver error to_dlpack/from_dlpack was causing
 issues with tensor with new jax version

---
 requirements_mtj.txt |  7 ++++---
 tpu_mtj_backend.py   | 26 ++++++++++++++++----------
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/requirements_mtj.txt b/requirements_mtj.txt
index 9447541f..ea399a8f 100644
--- a/requirements_mtj.txt
+++ b/requirements_mtj.txt
@@ -2,9 +2,10 @@ torch >= 1.9, < 1.13
 numpy
 tqdm
 requests
-dm-haiku == 0.0.5
-jax == 0.2.21
-jaxlib >= 0.1.69, <= 0.3.7
+dm-haiku==0.0.9
+jax==0.3.25
+jaxlib==0.3.25
+transformers == 4.28.0
 chex == 0.1.5
 transformers == 4.24.0
 huggingface_hub==0.12.1
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 4b27493e..f878d690 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1049,7 +1049,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
     raise RuntimeError(error)
 
 
-def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
+def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
     if "pad_token_id" in kwargs:
@@ -1195,6 +1195,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
 
+    global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()
     batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
 
@@ -1244,6 +1245,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         from tqdm.auto import tqdm
         import functools
 
+
         def callback(model_dict, f, **_):
             if callback.nested:
                 return
@@ -1317,19 +1319,20 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     #if "no_transpose" not in transforms and tensor.ndim == 2:
                     #    tensor = tensor.T
                     tensor.unsqueeze_(0)
-                    if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
-                        tensor = tensor.bfloat16()
+
 
                     # Shard the tensor so that parts of the tensor can be used
                     # on different TPU cores
+                    tensor = reshard_reverse(
+                        tensor,
+                        params["cores_per_replica"],
+                        network.state["params"][spec["module"]][spec["param"]].shape,
+                    )
+                    tensor = jnp.array(tensor.detach())
+                    if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
+                        tensor = tensor.bfloat16()
                     network.state["params"][spec["module"]][spec["param"]] = move_xmap(
-                        jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(
-                            reshard_reverse(
-                                tensor,
-                                params["cores_per_replica"],
-                                network.state["params"][spec["module"]][spec["param"]].shape,
-                            )
-                        )).copy(),
+                        tensor,
                         np.empty(params["cores_per_replica"]),
                     )
 
@@ -1416,3 +1419,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             model = GPTNeoForCausalLM.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
 
     #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+    global shard_xmap, batch_xmap
+    shard_xmap = __shard_xmap()
+    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
\ No newline at end of file
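
Note (not part of the patch): the core of the fix is the hunk at tpu_mtj_backend.py:1317, which drops the
PyTorch -> JAX DLPack round-trip and instead reshards the torch tensor first and copies it into a JAX array
with jnp.array(). The sketch below contrasts the two conversion paths in isolation, assuming a plain CPU
float32 tensor; the resharding via reshard_reverse(), the bfloat16 handling, and the move_xmap() placement
from the real code are omitted.

# Standalone sketch of the two torch -> JAX conversion paths touched by this
# patch. Assumes CPU, float32, and the torch/jax versions pinned above; it is
# illustrative only and not taken from the repository.
import jax.dlpack
import jax.numpy as jnp
import torch

t = torch.zeros(2, 3, dtype=torch.float32)

# Old path: zero-copy hand-off through the DLPack protocol. This is the call
# chain the patch removes because it errored on the TPU driver with the newer
# jax/jaxlib pinned in requirements_mtj.txt.
via_dlpack = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t)).copy()

# New path: detach the tensor and let jnp.array() copy it via the NumPy
# protocol. This pays for an extra host copy but avoids DLPack entirely.
via_copy = jnp.array(t.detach())

assert via_dlpack.shape == via_copy.shape == (2, 3)
assert via_dlpack.dtype == via_copy.dtype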