From 9a3f0eaab27afd26ad45496392f162748797a2a6 Mon Sep 17 00:00:00 2001
From: vfbd
Date: Mon, 21 Nov 2022 13:47:18 -0500
Subject: [PATCH] Only enable TPU transpose optimization if loading from HF model

---
 tpu_mtj_backend.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 067f7912..9bb1fda2 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1149,7 +1149,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             params[param] = default_params[param]
 
     # Use an optimization that will allow us to avoid one extra transpose operation
-    params["transposed_linear"] = True
+    if hf_checkpoint:
+        params["transposed_linear"] = True
 
     # Load tokenizer
     if vars.model == "TPUMeshTransformerGPTNeoX":