Pin TPU driver

This commit is contained in:
henk717 2023-04-23 20:21:28 +02:00 committed by GitHub
parent d88f109073
commit b808f039ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -1049,7 +1049,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
raise RuntimeError(error)
def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params, pad_token_id
if "pad_token_id" in kwargs: