Pin driver to the one from JAX 0.3.25

This commit is contained in:
henk717
2023-04-23 01:43:37 +02:00
committed by GitHub
parent df0ab18696
commit 8f44141f96

View File

@@ -1095,7 +1095,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
koboldai_vars.status_message = "" koboldai_vars.status_message = ""
def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None: def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params, pad_token_id global thread_resources_env, seq, tokenizer, network, params, pad_token_id
if "pad_token_id" in kwargs: if "pad_token_id" in kwargs: