diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 067f7912..9bb1fda2 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1149,7 +1149,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             params[param] = default_params[param]
 
     # Use an optimization that will allow us to avoid one extra transpose operation
-    params["transposed_linear"] = True
+    if hf_checkpoint:
+        params["transposed_linear"] = True
 
     # Load tokenizer
     if vars.model == "TPUMeshTransformerGPTNeoX":