From 8f44141f965941fa768a668991d11247a410149a Mon Sep 17 00:00:00 2001 From: henk717 Date: Sun, 23 Apr 2023 01:43:37 +0200 Subject: [PATCH] Pin driver to the one from JAX 0.3.25 --- tpu_mtj_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index dc0a664d..d8edb92f 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1095,7 +1095,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2): koboldai_vars.status_message = "" -def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None: +def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None: global thread_resources_env, seq, tokenizer, network, params, pad_token_id if "pad_token_id" in kwargs: @@ -1504,4 +1504,4 @@ def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=Fal #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica)) global shard_xmap, batch_xmap shard_xmap = __shard_xmap() - batch_xmap = __batch_xmap(shard_dim=cores_per_replica) \ No newline at end of file + batch_xmap = __batch_xmap(shard_dim=cores_per_replica)