diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 205f9cc4..fe20646c 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1049,7 +1049,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2): raise RuntimeError(error) -def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None: +def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None: global thread_resources_env, seq, tokenizer, network, params, pad_token_id if "pad_token_id" in kwargs: