Decrease TPU loading times by eliminating a transpose operation

This commit is contained in:
vfbd 2022-10-12 14:31:18 -04:00
parent 64715b18d6
commit bdc73ef393
1 changed files with 6 additions and 2 deletions

View File

@ -1148,6 +1148,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if param not in params: if param not in params:
params[param] = default_params[param] params[param] = default_params[param]
# Use an optimization that will allow us to avoid one extra transpose operation
params["transposed_linear"] = True
# Load tokenizer # Load tokenizer
if vars.model == "TPUMeshTransformerGPTNeoX": if vars.model == "TPUMeshTransformerGPTNeoX":
tokenizer = Tokenizer.from_file(os.path.join(path, "20B_tokenizer.json")) tokenizer = Tokenizer.from_file(os.path.join(path, "20B_tokenizer.json"))
@ -1305,8 +1308,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
tensor /= params["cores_per_replica"] tensor /= params["cores_per_replica"]
if "vocab_pad" in transforms: if "vocab_pad" in transforms:
tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"])) tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"]))
if "no_transpose" not in transforms and tensor.ndim == 2: # We don't need to transpose linear module weights anymore because MTJ will do it for us if `transposed_linear` is set to True in the config
tensor = tensor.T #if "no_transpose" not in transforms and tensor.ndim == 2:
# tensor = tensor.T
tensor.unsqueeze_(0) tensor.unsqueeze_(0)
if tensor.dtype is torch.float16 or tensor.dtype is torch.float32: if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
tensor = tensor.bfloat16() tensor = tensor.bfloat16()