Only enable TPU transpose optimization if loading from HF model

This commit is contained in:
vfbd 2022-11-21 13:47:18 -05:00
parent f2077b8e58
commit 9a3f0eaab2

View File

@@ -1149,7 +1149,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
params[param] = default_params[param] params[param] = default_params[param]
# Use an optimization that will allow us to avoid one extra transpose operation # Use an optimization that will allow us to avoid one extra transpose operation
params["transposed_linear"] = True if hf_checkpoint:
params["transposed_linear"] = True
# Load tokenizer # Load tokenizer
if vars.model == "TPUMeshTransformerGPTNeoX": if vars.model == "TPUMeshTransformerGPTNeoX":