diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index d992ba45..99efa05f 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1148,6 +1148,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         if param not in params:
             params[param] = default_params[param]
 
+    # Use an optimization that will allow us to avoid one extra transpose operation
+    params["transposed_linear"] = True
+
     # Load tokenizer
     if vars.model == "TPUMeshTransformerGPTNeoX":
         tokenizer = Tokenizer.from_file(os.path.join(path, "20B_tokenizer.json"))
@@ -1305,8 +1308,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     tensor /= params["cores_per_replica"]
                 if "vocab_pad" in transforms:
                     tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"]))
-                if "no_transpose" not in transforms and tensor.ndim == 2:
-                    tensor = tensor.T
+                # We don't need to transpose linear module weights anymore because MTJ will do it for us if `transposed_linear` is set to True in the config
+                #if "no_transpose" not in transforms and tensor.ndim == 2:
+                #    tensor = tensor.T
                 tensor.unsqueeze_(0)
                 if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
                     tensor = tensor.bfloat16()
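
For context, a minimal NumPy sketch (not part of the patch, and only an assumption about how MTJ handles the `transposed_linear` flag internally) of why the per-tensor transpose at checkpoint-load time becomes unnecessary when the matmul is written against the Hugging Face weight layout:

```python
import numpy as np

# Illustration only: the names below are made up for this sketch and do not
# exist in MTJ or KoboldAI.

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))       # activations: (batch, in_features)
w_hf = rng.standard_normal((16, 8))   # HF nn.Linear weight: (out_features, in_features)

# Old path: transpose every 2-D weight once while loading the checkpoint,
# then the layer does a plain matmul against (in_features, out_features).
w_loaded = w_hf.T
y_old = x @ w_loaded

# New path: keep w_hf in its checkpoint layout and express the contraction
# over in_features directly, so no transposed copy is materialized at load time.
y_new = np.einsum("bi,oi->bo", x, w_hf)

assert np.allclose(y_old, y_new)
```

Whether MTJ fuses the transpose into its matmul exactly like this is an assumption; the effect relied on by this patch is simply that, with `transposed_linear` set, the 2-D weights can be fed to MTJ in their original checkpoint layout.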