diff --git a/aiserver.py b/aiserver.py index bc5cc3ba..41899642 100644 --- a/aiserver.py +++ b/aiserver.py @@ -317,7 +317,7 @@ def getmodelname(): if(args.configname): modelname = args.configname return modelname - if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ")): + if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): modelname = os.path.basename(os.path.normpath(vars.custmodpth)) return modelname else: @@ -699,7 +699,7 @@ def spRequest(filename): vars.sp_length = tensor.shape[-2] vars.spmeta["n_tokens"] = vars.sp_length - if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): + if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): rows = tensor.shape[0] padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows tensor = np.pad(tensor, ((0, padding_amount), (0, 0))) @@ -790,7 +790,7 @@ else: getModelSelection(mainmenu) # If transformers model was selected & GPU available, ask to use CPU or GPU -if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]): +if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): vars.allowsp = True # Test for GPU support import torch @@ -830,7 +830,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)") vars.model_type = "gpt_neo" -if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]): +if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): loadmodelsettings() loadsettings() print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="") @@ -1032,7 +1032,7 @@ socketio = SocketIO(app, async_method="eventlet") print("{0}OK!{1}".format(colors.GREEN, colors.END)) # Start transformers and create pipeline -if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]): +if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): if(not vars.noai): print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END)) from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer @@ -1050,7 +1050,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go if not vars.lazy_load: return - from tqdm import tqdm + from tqdm.auto import tqdm if "breakmodel" in globals(): gpu_blocks = breakmodel.gpu_blocks @@ -1553,9 +1553,9 @@ else: tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/") loadsettings() # Load the TPU backend if requested - elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): + elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) - if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)): + if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)): raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder") import tpu_mtj_backend tpu_mtj_backend.vars = vars @@ -1567,7 +1567,7 @@ else: vars.allowsp = True loadmodelsettings() loadsettings() - tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model != "TPUMeshTransformerGPTJ" and vars.use_colab_tpu, **vars.modelconfig) + tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig) vars.modeldim = int(tpu_mtj_backend.params["d_model"]) tokenizer = tpu_mtj_backend.tokenizer else: @@ -2098,7 +2098,7 @@ def lua_get_modeltype(): return "readonly" if(vars.model in ("Colab", "OAI", "InferKit")): return "api" - if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))): + if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))): hidden_size = get_hidden_size_from_model(model) if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)): return "gpt2" @@ -2127,7 +2127,7 @@ def lua_get_modelbackend(): return "readonly" if(vars.model in ("Colab", "OAI", "InferKit")): return "api" - if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): + if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): return "mtj" return "transformers" @@ -2136,7 +2136,7 @@ def lua_get_modelbackend(): #==================================================================# @bridged_kwarg() def lua_is_custommodel(): - return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ") + return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") #==================================================================# # @@ -3074,22 +3074,22 @@ def calcsubmit(txt): if(vars.model != "InferKit"): subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt) if(actionlen == 0): - if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): + if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "OAI"): oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) - elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): + elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): tpumtjgenerate(subtxt, min, max, found_entries=found_entries) else: - if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): + if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "OAI"): oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) - elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): + elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): tpumtjgenerate(subtxt, min, max, found_entries=found_entries) # For InferKit web API @@ -5105,7 +5105,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")): file.close() # Precompile TPU backend if required -if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): +if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): soft_tokens = tpumtjgetsofttokens() if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)): threading.Thread( diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index b13e3aa3..b36b88ba 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -46,6 +46,7 @@ import numpy as np import optax import haiku as hk from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM +from tokenizers import Tokenizer from mesh_transformer.checkpoint import read_ckpt_lowmem from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor from mesh_transformer.util import to_bf16 @@ -800,6 +801,121 @@ def reshard_reverse(x, total_shards, old_shape): return out +def get_old_shape(t, total_shards, dim=2): + if len(t.shape) == 2: + shard_shape = t.shape + if dim == 1: + assert shard_shape[0] % total_shards == 0 + return (shard_shape[0] // total_shards, shard_shape[1]) + elif dim == 2: + assert shard_shape[1] % total_shards == 0 + return (shard_shape[0], shard_shape[1] // total_shards) + else: + raise ValueError(f"Unsupported dim {dim}") + if len(t.shape) == 1: + assert t.shape[0] % total_shards == 0 + return (t.shape[0] // total_shards,) + else: + raise ValueError(f"Unsupported shape {t.shape}") + + +def read_neox_checkpoint(state, path, config, checkpoint_shards=2): + assert config["cores_per_replica"] % checkpoint_shards == 0 + output_shards = config["cores_per_replica"] // checkpoint_shards + + import torch + from tqdm.auto import tqdm + + move_xmap = jax.experimental.maps.xmap( + fun=lambda x, _: to_bf16(x), + in_axes=(["shard", ...], ["batch", ...]), + out_axes=["shard", ...], + axis_resources={'shard': 'mp', 'batch': 'dp'} + ) + + path_template = os.path.join(path, "layer_{layer:02d}-model_{shard:02d}-model_states.pt") + + static_mapping = { + "word_embeddings.weight": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1}, + "final_linear.weight": {"module": "projection_shard/~/linear", "param": "w", "axis": 2}, + "norm.weight": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale", "axis": None}, + "norm.bias": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset", "axis": None}, + } + + layer_mapping = { + "attention.query_key_value.weight": {"module": "combined_qkv", "param": "w", "axis": 2}, + "attention.query_key_value.bias": {"module": "combined_qkv", "param": "b", "axis": 1}, + "attention.dense.weight": {"module": "linear_3", "param": "w", "axis": 1}, + "attention.dense.bias": {"module": "linear_3", "param": "b", "axis": None}, + "mlp.dense_h_to_4h.weight": {"module": "linear_4", "param": "w", "axis": 2}, + "mlp.dense_h_to_4h.bias": {"module": "linear_4", "param": "b", "axis": 1}, + "mlp.dense_4h_to_h.weight": {"module": "linear_5", "param": "w", "axis": 1}, + "mlp.dense_4h_to_h.bias": {"module": "linear_5", "param": "b", "axis": None}, + "input_layernorm.weight": {"module": "replicated_layer_norm", "param": "scale", "axis": None}, + "input_layernorm.bias": {"module": "replicated_layer_norm", "param": "offset", "axis": None}, + "post_attention_layernorm.weight": {"module": "replicated_layer_norm_1", "param": "scale", "axis": None}, + "post_attention_layernorm.bias": {"module": "replicated_layer_norm_1", "param": "offset", "axis": None}, + } + + tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping) + bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint") + + for checkpoint_layer in range(config["layers"] + 5): + if checkpoint_layer in (1, config["layers"] + 2): + continue + layer = checkpoint_layer - 2 + shards = [] + for checkpoint_shard in range(checkpoint_shards): + shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu")) + for key in shards[0]: + if key == "attention.rotary_emb.inv_freq": + continue + elif key in static_mapping: + target_module = "causal_transformer_shard/~/" + static_mapping[key]["module"] + target_param = static_mapping[key]["param"] + target_axis = static_mapping[key]["axis"] + elif key in layer_mapping: + target_module = f"causal_transformer_shard/~/layer_{layer}/~/" + layer_mapping[key]["module"] + target_param = layer_mapping[key]["param"] + target_axis = layer_mapping[key]["axis"] + else: + error = f"{repr(key)} not found in mapping" + print("\n\nERROR: ", error, file=sys.stderr) + raise RuntimeError(error) + original_shape = shards[0][key].shape + for checkpoint_shard in range(checkpoint_shards): + if key in ("attention.dense.bias", "mlp.dense_4h_to_h.bias"): + shards[checkpoint_shard][key] /= output_shards + if key != "word_embeddings.weight": + shards[checkpoint_shard][key] = shards[checkpoint_shard][key].T + tensor = shards[checkpoint_shard][key] + if target_axis is not None: + target_shape = (output_shards,) + get_old_shape(tensor, total_shards=output_shards, dim=target_axis) + else: + target_shape = (output_shards, tensor.shape[0]) + shards[checkpoint_shard][key] = reshard_reverse(tensor.unsqueeze_(0), output_shards, target_shape) + #print(key, ":", original_shape, "->", shards[0][key].shape) + tensor = torch.cat([shards[s][key] for s in range(checkpoint_shards)], dim=0) + target_shape = state["params"][target_module][target_param].shape + if tensor.shape != target_shape: + error = f"Weight {repr(key)} has shape {tensor.shape} in checkpoint but shape {target_shape} was requested by MTJ for {target_module} {target_param}" + print("\n\nERROR: ", error, file=sys.stderr) + raise RuntimeError(error) + if tensor.dtype is torch.float16 or tensor.dtype is torch.float32: + tensor = tensor.bfloat16() + state["params"][target_module][target_param] = move_xmap( + jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)).copy(), + np.zeros(config["cores_per_replica"]), + ) + bar.update(1) + for mk, mv in state["params"].items(): + for pk, pv in mv.items(): + if isinstance(pv, PlaceholderTensor): + error = f"{mk} {pk} could not be found in the model checkpoint" + print("\n\nERROR: " + error, file=sys.stderr) + raise RuntimeError(error) + + def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None: global thread_resources_env, seq, tokenizer, network, params @@ -820,6 +936,23 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo } params = kwargs + if vars.model == "TPUMeshTransformerGPTNeoX": + default_params = { + "compat": "neox", + "layers": 44, + "d_model": 6144, + "n_heads": 64, + "n_vocab": 50432, + "n_vocab_padding": 0, + "norm": "doublelayernorm", + "pe": "rotary", + "pe_rotary_dims": 24, + "seq": 2048, + "cores_per_replica": 8, + "tokenizer_class": "GPT2TokenizerFast", + "tokenizer": "gpt2", + } + # Try to convert HF config.json to MTJ config if hf_checkpoint: spec_path = os.path.join("maps", vars.model_type + ".json") @@ -875,7 +1008,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo params[param] = default_params[param] # Load tokenizer - if not hf_checkpoint: + if vars.model == "TPUMeshTransformerGPTNeoX": + tokenizer = Tokenizer.from_file(os.path.join(path, "20B_tokenizer.json")) + def new_encode(old_encode): + def encode(s, *args, **kwargs): + return old_encode(s).ids + return encode + tokenizer.encode = new_encode(tokenizer.encode) + elif not hf_checkpoint: if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")): raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'") tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"]) @@ -917,9 +1057,13 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo network = PenalizingCausalTransformer(params, dematerialized=True) - if not hf_checkpoint: + if not hf_checkpoint and vars.model != "TPUMeshTransformerGPTNeoX": network.state = read_ckpt_lowmem(network.state, path, devices.shape[1]) - network.state = network.move_xmap(network.state, np.zeros(cores_per_replica)) + #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica)) + return + + if vars.model == "TPUMeshTransformerGPTNeoX": + read_neox_checkpoint(network.state, path, params) return # Convert from HF checkpoint @@ -945,7 +1089,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo import torch_lazy_loader import torch - from tqdm import tqdm + from tqdm.auto import tqdm def callback(model_dict, f, **_): with zipfile.ZipFile(f, "r") as z: @@ -1069,4 +1213,4 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo except Exception as e: model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache") - network.state = network.move_xmap(network.state, np.zeros(cores_per_replica)) + #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))