GPT-NeoX-20B support in Colab TPU instances

This commit is contained in:
Gnome Ann
2022-03-14 23:14:20 -04:00
parent 4892556059
commit 88f247d535
2 changed files with 166 additions and 22 deletions

View File

@ -46,6 +46,7 @@ import numpy as np
import optax
import haiku as hk
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
from tokenizers import Tokenizer
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
from mesh_transformer.util import to_bf16
@ -800,6 +801,121 @@ def reshard_reverse(x, total_shards, old_shape):
return out
def get_old_shape(t, total_shards, dim=2):
if len(t.shape) == 2:
shard_shape = t.shape
if dim == 1:
assert shard_shape[0] % total_shards == 0
return (shard_shape[0] // total_shards, shard_shape[1])
elif dim == 2:
assert shard_shape[1] % total_shards == 0
return (shard_shape[0], shard_shape[1] // total_shards)
else:
raise ValueError(f"Unsupported dim {dim}")
if len(t.shape) == 1:
assert t.shape[0] % total_shards == 0
return (t.shape[0] // total_shards,)
else:
raise ValueError(f"Unsupported shape {t.shape}")
def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
assert config["cores_per_replica"] % checkpoint_shards == 0
output_shards = config["cores_per_replica"] // checkpoint_shards
import torch
from tqdm.auto import tqdm
move_xmap = jax.experimental.maps.xmap(
fun=lambda x, _: to_bf16(x),
in_axes=(["shard", ...], ["batch", ...]),
out_axes=["shard", ...],
axis_resources={'shard': 'mp', 'batch': 'dp'}
)
path_template = os.path.join(path, "layer_{layer:02d}-model_{shard:02d}-model_states.pt")
static_mapping = {
"word_embeddings.weight": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1},
"final_linear.weight": {"module": "projection_shard/~/linear", "param": "w", "axis": 2},
"norm.weight": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale", "axis": None},
"norm.bias": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset", "axis": None},
}
layer_mapping = {
"attention.query_key_value.weight": {"module": "combined_qkv", "param": "w", "axis": 2},
"attention.query_key_value.bias": {"module": "combined_qkv", "param": "b", "axis": 1},
"attention.dense.weight": {"module": "linear_3", "param": "w", "axis": 1},
"attention.dense.bias": {"module": "linear_3", "param": "b", "axis": None},
"mlp.dense_h_to_4h.weight": {"module": "linear_4", "param": "w", "axis": 2},
"mlp.dense_h_to_4h.bias": {"module": "linear_4", "param": "b", "axis": 1},
"mlp.dense_4h_to_h.weight": {"module": "linear_5", "param": "w", "axis": 1},
"mlp.dense_4h_to_h.bias": {"module": "linear_5", "param": "b", "axis": None},
"input_layernorm.weight": {"module": "replicated_layer_norm", "param": "scale", "axis": None},
"input_layernorm.bias": {"module": "replicated_layer_norm", "param": "offset", "axis": None},
"post_attention_layernorm.weight": {"module": "replicated_layer_norm_1", "param": "scale", "axis": None},
"post_attention_layernorm.bias": {"module": "replicated_layer_norm_1", "param": "offset", "axis": None},
}
tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping)
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
for checkpoint_layer in range(config["layers"] + 5):
if checkpoint_layer in (1, config["layers"] + 2):
continue
layer = checkpoint_layer - 2
shards = []
for checkpoint_shard in range(checkpoint_shards):
shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
for key in shards[0]:
if key == "attention.rotary_emb.inv_freq":
continue
elif key in static_mapping:
target_module = "causal_transformer_shard/~/" + static_mapping[key]["module"]
target_param = static_mapping[key]["param"]
target_axis = static_mapping[key]["axis"]
elif key in layer_mapping:
target_module = f"causal_transformer_shard/~/layer_{layer}/~/" + layer_mapping[key]["module"]
target_param = layer_mapping[key]["param"]
target_axis = layer_mapping[key]["axis"]
else:
error = f"{repr(key)} not found in mapping"
print("\n\nERROR: ", error, file=sys.stderr)
raise RuntimeError(error)
original_shape = shards[0][key].shape
for checkpoint_shard in range(checkpoint_shards):
if key in ("attention.dense.bias", "mlp.dense_4h_to_h.bias"):
shards[checkpoint_shard][key] /= output_shards
if key != "word_embeddings.weight":
shards[checkpoint_shard][key] = shards[checkpoint_shard][key].T
tensor = shards[checkpoint_shard][key]
if target_axis is not None:
target_shape = (output_shards,) + get_old_shape(tensor, total_shards=output_shards, dim=target_axis)
else:
target_shape = (output_shards, tensor.shape[0])
shards[checkpoint_shard][key] = reshard_reverse(tensor.unsqueeze_(0), output_shards, target_shape)
#print(key, ":", original_shape, "->", shards[0][key].shape)
tensor = torch.cat([shards[s][key] for s in range(checkpoint_shards)], dim=0)
target_shape = state["params"][target_module][target_param].shape
if tensor.shape != target_shape:
error = f"Weight {repr(key)} has shape {tensor.shape} in checkpoint but shape {target_shape} was requested by MTJ for {target_module} {target_param}"
print("\n\nERROR: ", error, file=sys.stderr)
raise RuntimeError(error)
if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
tensor = tensor.bfloat16()
state["params"][target_module][target_param] = move_xmap(
jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)).copy(),
np.zeros(config["cores_per_replica"]),
)
bar.update(1)
for mk, mv in state["params"].items():
for pk, pv in mv.items():
if isinstance(pv, PlaceholderTensor):
error = f"{mk} {pk} could not be found in the model checkpoint"
print("\n\nERROR: " + error, file=sys.stderr)
raise RuntimeError(error)
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params
@ -820,6 +936,23 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
}
params = kwargs
if vars.model == "TPUMeshTransformerGPTNeoX":
default_params = {
"compat": "neox",
"layers": 44,
"d_model": 6144,
"n_heads": 64,
"n_vocab": 50432,
"n_vocab_padding": 0,
"norm": "doublelayernorm",
"pe": "rotary",
"pe_rotary_dims": 24,
"seq": 2048,
"cores_per_replica": 8,
"tokenizer_class": "GPT2TokenizerFast",
"tokenizer": "gpt2",
}
# Try to convert HF config.json to MTJ config
if hf_checkpoint:
spec_path = os.path.join("maps", vars.model_type + ".json")
@ -875,7 +1008,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
params[param] = default_params[param]
# Load tokenizer
if not hf_checkpoint:
if vars.model == "TPUMeshTransformerGPTNeoX":
tokenizer = Tokenizer.from_file(os.path.join(path, "20B_tokenizer.json"))
def new_encode(old_encode):
def encode(s, *args, **kwargs):
return old_encode(s).ids
return encode
tokenizer.encode = new_encode(tokenizer.encode)
elif not hf_checkpoint:
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
@ -917,9 +1057,13 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
network = PenalizingCausalTransformer(params, dematerialized=True)
if not hf_checkpoint:
if not hf_checkpoint and vars.model != "TPUMeshTransformerGPTNeoX":
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
#network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
return
if vars.model == "TPUMeshTransformerGPTNeoX":
read_neox_checkpoint(network.state, path, params)
return
# Convert from HF checkpoint
@ -945,7 +1089,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
import torch_lazy_loader
import torch
from tqdm import tqdm
from tqdm.auto import tqdm
def callback(model_dict, f, **_):
with zipfile.ZipFile(f, "r") as z:
@ -1069,4 +1213,4 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
except Exception as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
#network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))