From 4502a3f6b8dfe49523a0b94e8f461b5ea6bb208d Mon Sep 17 00:00:00 2001
From: Bogdan Drema
Date: Sat, 22 Apr 2023 22:31:21 +0100
Subject: [PATCH] Fix: TPU driver error; to_dlpack/from_dlpack was causing tensor issues with the new JAX version

---
 requirements_mtj.txt |  6 +++---
 tpu_mtj_backend.py   | 29 ++++++++++++++---------------
 2 files changed, 17 insertions(+), 18 deletions(-)

diff --git a/requirements_mtj.txt b/requirements_mtj.txt
index 37b76a23..19da3910 100644
--- a/requirements_mtj.txt
+++ b/requirements_mtj.txt
@@ -2,9 +2,9 @@ torch >= 1.9, < 1.13
 numpy
 tqdm
 requests
-dm-haiku == 0.0.5
-jax == 0.2.21
-jaxlib >= 0.1.69, <= 0.3.7
+dm-haiku==0.0.9
+jax==0.3.25
+jaxlib==0.3.25
 transformers == 4.28.0
 chex == 0.1.5
 huggingface_hub==0.12.1
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 02754d95..dc0a664d 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1095,7 +1095,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
     koboldai_vars.status_message = ""
 
 
-def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
+def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
     if "pad_token_id" in kwargs:
@@ -1270,11 +1270,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {koboldai_vars.cloudflare_link}")
             logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {koboldai_vars.cloudflare_link}/new_ui")
 
-
-    global shard_xmap, batch_xmap
-    shard_xmap = __shard_xmap()
-    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
-
     global badwords
     # These are the tokens that we don't want the AI to ever write
     badwords = jnp.array(koboldai_vars.badwordsids).squeeze()
@@ -1401,19 +1396,20 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     #if "no_transpose" not in transforms and tensor.ndim == 2:
                     #    tensor = tensor.T
                     tensor.unsqueeze_(0)
-                    if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
-                        tensor = tensor.bfloat16()
+
                     # Shard the tensor so that parts of the tensor can be used
                     # on different TPU cores
+                    tensor = reshard_reverse(
+                        tensor,
+                        params["cores_per_replica"],
+                        network.state["params"][spec["module"]][spec["param"]].shape,
+                    )
+                    tensor = jnp.array(tensor.detach())
+                    if tensor.dtype in (jnp.float16, jnp.float32):
+                        tensor = tensor.astype(jnp.bfloat16)
                     network.state["params"][spec["module"]][spec["param"]] = move_xmap(
-                        jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(
-                            reshard_reverse(
-                                tensor,
-                                params["cores_per_replica"],
-                                network.state["params"][spec["module"]][spec["param"]].shape,
-                            )
-                        )).copy(),
+                        tensor,
                         np.empty(params["cores_per_replica"]),
                     )
@@ -1506,3 +1502,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             model = GPTNeoForCausalLM.from_pretrained(koboldai_vars.model, revision=koboldai_vars.revision, cache_dir="cache")
 
     #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+    global shard_xmap, batch_xmap
+    shard_xmap = __shard_xmap()
+    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
\ No newline at end of file
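
---

For context, and not part of the commit itself: a minimal, self-contained sketch of
the torch-to-JAX conversion the last hunk switches to. The tensor name and shape
below are invented for illustration; KoboldAI's real loading loop additionally wraps
this step with reshard_reverse() and move_xmap().

    import jax.numpy as jnp
    import torch

    # Stand-in for one checkpoint weight after tensor.unsqueeze_(0).
    tensor = torch.randn(1, 8, 8)

    # Old path, which broke against the newer jax/jaxlib and TPU driver:
    # round-trip the tensor through a DLPack capsule.
    #   capsule = torch.utils.dlpack.to_dlpack(tensor)
    #   jax_tensor = jax.dlpack.from_dlpack(capsule).copy()

    # New path: detach from the autograd graph and copy the (CPU) tensor
    # into a fresh JAX array; no DLPack capsule is involved.
    jax_tensor = jnp.array(tensor.detach())

    # Downcast on the JAX side. torch bfloat16 tensors cannot be viewed as
    # NumPy arrays, so the cast has to happen after the conversion.
    if jax_tensor.dtype in (jnp.float16, jnp.float32):
        jax_tensor = jax_tensor.astype(jnp.bfloat16)

    print(jax_tensor.dtype, jax_tensor.shape)  # bfloat16 (1, 8, 8)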