diff --git a/aiserver.py b/aiserver.py
index e78211ae..2d1998e6 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -695,7 +695,7 @@ def spRequest(filename):
     vars.sp_length = tensor.shape[-2]
     vars.spmeta["n_tokens"] = vars.sp_length

-    if(vars.model in ("TPUMeshTransformerGPTJ",)):
+    if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
         rows = tensor.shape[0]
         padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
         tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
@@ -730,6 +730,7 @@ parser.add_argument("--override_delete", action='store_true', help="Deleting sto
 parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
 parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
 parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")
+parser.add_argument("--colab_tpu", action='store_true', help="If you're running KoboldAI in a Google Colab TPU instance, enable this to load Hugging Face models onto the TPU.")
 parser.add_argument("--nobreakmodel", action='store_true', help="Disables Breakmodel support completely.")
 parser.add_argument("--unblock", action='store_true', default=False, help="Unblocks the KoboldAI port to be accessible from other machines without optimizing for remote play (It is recommended to use --host instead)")
 parser.add_argument("--quiet", action='store_true', default=False, help="If present will suppress any story related text from showing on the console")
@@ -783,7 +784,7 @@ else:
     getModelSelection(mainmenu)

 # If transformers model was selected & GPU available, ask to use CPU or GPU
-if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
+if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     vars.allowsp = True
     # Test for GPU support
     import torch
@@ -822,6 +823,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     elif(vars.model_type == "not_found"):
         print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
         vars.model_type = "gpt_neo"
+
+if(not args.colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     loadmodelsettings()
     loadsettings()
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -1014,7 +1017,7 @@ socketio = SocketIO(app, async_method="eventlet")
 print("{0}OK!{1}".format(colors.GREEN, colors.END))

 # Start transformers and create pipeline
-if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
+if(not args.colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
         from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
@@ -1523,9 +1526,9 @@ else:
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
         loadsettings()
     # Load the TPU backend if requested
-    elif(vars.model == "TPUMeshTransformerGPTJ"):
+    elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
== "TPUMeshTransformerGPTJ"): print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) - if not vars.custmodpth or not os.path.isdir(vars.custmodpth): + if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)): raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder") import tpu_mtj_backend tpu_mtj_backend.vars = vars @@ -1537,7 +1540,7 @@ else: vars.allowsp = True loadmodelsettings() loadsettings() - tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig) + tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=args.colab_tpu, **vars.modelconfig) vars.modeldim = int(tpu_mtj_backend.params["d_model"]) tokenizer = tpu_mtj_backend.tokenizer else: @@ -2068,7 +2071,7 @@ def lua_get_modeltype(): return "readonly" if(vars.model in ("Colab", "OAI", "InferKit")): return "api" - if(vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))): + if(not args.colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))): hidden_size = get_hidden_size_from_model(model) if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)): return "gpt2" @@ -2084,7 +2087,7 @@ def lua_get_modeltype(): return "gpt-neo-1.3B" if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model_type == "gpt_neo" and hidden_size == 2560)): return "gpt-neo-2.7B" - if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)): + if(vars.model in ("EleutherAI/gpt-j-6B",) or ((args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ") and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)): return "gpt-j-6B" return "unknown" @@ -2097,7 +2100,7 @@ def lua_get_modelbackend(): return "readonly" if(vars.model in ("Colab", "OAI", "InferKit")): return "api" - if(vars.model in ("TPUMeshTransformerGPTJ",)): + if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): return "mtj" return "transformers" @@ -3044,22 +3047,22 @@ def calcsubmit(txt): if(vars.model != "InferKit"): subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt) if(actionlen == 0): - if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): + if(not args.colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "OAI"): oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) - elif(vars.model == "TPUMeshTransformerGPTJ"): + elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): tpumtjgenerate(subtxt, min, max, found_entries=found_entries) else: - if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): + if(not args.colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "OAI"): oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) - 
-            elif(vars.model == "TPUMeshTransformerGPTJ"):
+            elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
                 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)

     # For InferKit web API
@@ -5071,7 +5074,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
         file.close()

 # Precompile TPU backend if required
-if(vars.model in ("TPUMeshTransformerGPTJ",)):
+if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
     soft_tokens = tpumtjgetsofttokens()
     if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
         threading.Thread(
diff --git a/maps/gpt_neo.json b/maps/gpt_neo.json
index a93fb26e..fa2d4084 100644
--- a/maps/gpt_neo.json
+++ b/maps/gpt_neo.json
@@ -1,25 +1,32 @@
 {
+    "mtj_compat": "neo",
+    "mtj_pe": "fixed",
+    "mtj_config_map": {
+        "d_model": "hidden_size",
+        "n_heads": "num_heads",
+        "layers": "num_layers"
+    },
     "static_weights": {
-        "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
-        "transformer.wpe.weight": {"mtj": {"module": "embedding_shard/~/pos_embs", "param": "w", "axis": 2, "transforms": ["transpose"]}},
-        "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
-        "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}
+        "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
+        "transformer.wpe.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose"]}},
+        "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
+        "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}
     },
     "layer_weights": {
         "transformer.h.{layer}.attn.attention.bias": {},
         "transformer.h.{layer}.attn.attention.masked_bias": {},
-        "transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
+        "transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
+        "transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
+        "transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
+        "transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
         "transformer.h.{layer}.attn.attention.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
-        "transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
-        "transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
+        "transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
"layer_{layer}/~/linear_4", "param": "w"}}, + "transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}}, + "transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}}, "transformer.h.{layer}.mlp.c_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, - "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, - "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}, - "transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}}, - "transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}} + "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}}, + "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}}, + "transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}}, + "transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}} } } diff --git a/maps/gptj.json b/maps/gptj.json index 51a788d7..8e0bc9da 100644 --- a/maps/gptj.json +++ b/maps/gptj.json @@ -1,24 +1,32 @@ { + "mtj_compat": "j", + "mtj_pe": "rotary", + "mtj_config_map": { + "pe_rotary_dims": ["rotary_dim", 64], + "d_model": "n_embd", + "n_heads": "n_head", + "layers": "n_layer" + }, "static_weights": { - "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}}, + "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, "transformer.wte.bias": {"mtj": {"module": "embedding_shard/~/linear", "param": "b"}}, - "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}}, - "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}, - "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}}, + "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}}, + "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}, + "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}}, "lm_head.bias": {"mtj": {"module": "projection_shard/~/linear", "param": "b"}} }, "layer_weights": { "transformer.h.{layer}.attn.bias": {}, "transformer.h.{layer}.attn.masked_bias": {}, - "transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}}, - "transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}}, - "transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}}, - "transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}}, - "transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}}, - "transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", 
"param": "b", "axis": 1}}, - "transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}}, + "transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}}, + "transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}}, + "transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}}, + "transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}}, + "transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}}, + "transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}}, + "transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}}, "transformer.h.{layer}.mlp.fc_out.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, - "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, - "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}} + "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}}, + "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}} } } diff --git a/maps/xglm.json b/maps/xglm.json index beb90985..3ba4b1f2 100644 --- a/maps/xglm.json +++ b/maps/xglm.json @@ -1,26 +1,33 @@ { + "mtj_compat": "fairseq_lm", + "mtj_pe": "fairseq_sinusoidal", + "mtj_config_map": { + "d_model": "d_model", + "n_heads": "attention_heads", + "layers": "num_layers" + }, "static_weights": { - "model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}}, - "model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}}, - "model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}, - "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}} + "model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, + "model.replicated_layer_norm.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}}, + "model.replicated_layer_norm.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}, + "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}} }, "layer_weights": { - "model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}}, - "model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b", "axis": 1}}, - "model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}}, - "model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b", "axis": 1}}, - "model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}}, - "model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b", "axis": 1}}, - 
"model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}}, + "model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}}, + "model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b"}}, + "model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}}, + "model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b"}}, + "model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}}, + "model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b"}}, + "model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}}, "model.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}}, - "model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}}, - "model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}}, - "model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}}, + "model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}}, + "model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}}, + "model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}}, "model.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, - "model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, - "model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}, - "model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}}, - "model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}} + "model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}}, + "model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}}, + "model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}}, + "model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}} } } diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index a78f93f2..6f5500b7 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -32,6 +32,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar import progressbar import time import os +import sys +import json +import zipfile import requests import random import jax @@ -41,9 +44,10 @@ import jax.numpy as jnp import numpy as np import optax import haiku as hk -import transformers +from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM from mesh_transformer.checkpoint import read_ckpt_lowmem -from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard +from 
+from mesh_transformer.util import to_bf16


 params: Dict[str, Any] = {}
@@ -776,7 +780,26 @@ def infer_static(
     return samples


-def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
+def reshard_reverse(x, total_shards, old_shape):
+    assert len(x.shape) != 1
+    if len(x.shape) == 2:
+        if old_shape[1] == x.shape[1]:
+            out = x[0:1].tile((total_shards, 1))
+        else:
+            out = x.reshape(old_shape)
+    elif len(x.shape) == 3:
+        if x.shape[0] * x.shape[2] == old_shape[2]:
+            out = x.reshape(old_shape)
+        elif x.shape[0] * x.shape[1] == old_shape[1]:
+            out = x.reshape((old_shape[1], old_shape[0], old_shape[2])).permute((1, 0, 2))
+        else:
+            assert False
+    else:
+        assert False
+    return out
+
+
+def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params

     default_params = {
@@ -795,6 +818,53 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         "tokenizer": "gpt2",
     }
     params = kwargs
+
+    # Try to convert HF config.json to MTJ config
+    if hf_checkpoint:
+        spec_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "maps", vars.model_type + ".json")
+        if not os.path.isfile(spec_path):
+            raise NotImplementedError(f"Unsupported model type {repr(vars.model_type)}")
+        with open(spec_path) as f:
+            lazy_load_spec = json.load(f)
+
+        if "mtj_compat" in lazy_load_spec:
+            params["compat"] = lazy_load_spec["mtj_compat"]
+        if "mtj_pe" in lazy_load_spec:
+            params["pe"] = lazy_load_spec["mtj_pe"]
+        for k, v in lazy_load_spec.get("mtj_config_map", {}).items():
+            if type(v) is not list:
+                params[k] = params[v]
+                continue
+            for i in range(len(v)):
+                if i == len(v) - 1:
+                    params[k] = v[i]
+                elif v[i] in params:
+                    params[k] = params[v[i]]
+                    break
+
+        params["n_vocab"] = params["vocab_size"]
+
+        if "activation_function" in params:
+            params["activation"] = params["activation_function"]
+
+        # Both the number of attention heads in the model and the embedding
+        # dimension of the model need to be divisible by the number of TPU cores
+        # that we use, and JAX also requires the number of TPU cores used to be
+        # an even number if we're using more than one core, so logically we try
+        # to pick the largest possible even number of TPU cores such that the
+        # number of attention heads and embedding dimension are both divisible
+        # by the number of TPU cores, and fall back to one core if an even
+        # number of TPU cores is not possible.
+        for c in (8, 6, 4, 2, 1):
+            if 0 == params["n_heads"] % c == params["d_model"] % c:
+                params["cores_per_replica"] = c
+                break
+
+        # The vocabulary size of the model also has to be divisible by the
+        # number of TPU cores, so we pad the vocabulary with the minimum
+        # possible number of dummy tokens such that it's divisible.
+ params["n_vocab_padding"] = -(params["n_vocab"] % -params["cores_per_replica"]) + if "compat" in params: default_params["compat"] = params["compat"] if default_params["compat"] == "fairseq_lm": @@ -804,10 +874,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) params[param] = default_params[param] # Load tokenizer - if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")): - raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'") - tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"]) - tokenizer = tokenizer_class.from_pretrained(params["tokenizer"]) + if not hf_checkpoint: + if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")): + raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'") + tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"]) + tokenizer = tokenizer_class.from_pretrained(params["tokenizer"]) # Disable JAX warnings about these two functions having been renamed jax.host_count = jax.process_count @@ -844,5 +915,147 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) path += "/" network = PenalizingCausalTransformer(params, dematerialized=True) - network.state = read_ckpt_lowmem(network.state, path, devices.shape[1]) + + if not hf_checkpoint: + network.state = read_ckpt_lowmem(network.state, path, devices.shape[1]) + network.state = network.move_xmap(network.state, np.zeros(cores_per_replica)) + return + + # Convert from HF checkpoint + + move_xmap = jax.experimental.maps.xmap( + fun=lambda x, _: to_bf16(x), + in_axes=(["shard", ...], ["batch", ...]), + out_axes=["shard", ...], + axis_resources={'shard': 'mp', 'batch': 'dp'} + ) + + model_spec = {} + for key, spec in lazy_load_spec.get("static_weights", {}).items(): + if spec.get("mtj") is not None: + model_spec[key] = spec["mtj"].copy() + model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"] + for _key, spec in lazy_load_spec.get("layer_weights", {}).items(): + for layer in range(params["layers"]): + if spec.get("mtj") is not None: + key = _key.format(layer=layer) + model_spec[key] = spec["mtj"].copy() + model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer) + + import torch_lazy_loader + import torch + from tqdm import tqdm + + def callback(model_dict, f, **_): + with zipfile.ZipFile(f, "r") as z: + try: + last_storage_key = None + f = None + print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n") + for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"): + + # Some model weights are used by transformers but not by MTJ. + # We have to materialize these weights anyways because + # transformers will throw a tantrum otherwise. To attain + # the least possible memory usage, we create them as meta + # tensors, which don't take up any actual CPU or TPU memory. 
+                    if key not in model_spec:
+                        model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].storage_type(0).dtype, device="meta")
+                        continue
+
+                    storage_key = model_dict[key].key
+                    if storage_key != last_storage_key:
+                        last_storage_key = storage_key
+                        if isinstance(f, zipfile.ZipExtFile):
+                            f.close()
+                        f = z.open(f"archive/data/{storage_key}")
+                    current_offset = f.tell()
+                    if current_offset != model_dict[key].seek_offset:
+                        f.seek(model_dict[key].seek_offset - current_offset, 1)
+                    spec = model_spec[key]
+                    transforms = set(spec.get("transforms", ()))
+                    if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
+                        error = f"Duplicate key {repr(key)}"
+                        print("\n\nERROR: " + error, file=sys.stderr)
+                        raise RuntimeError(error)
+                    tensor = model_dict[key].materialize(f, map_location="cpu")
+                    model_dict[key] = tensor.to("meta")
+
+                    # MTJ requires certain mathematical operations to be performed
+                    # on tensors in order for them to be in the correct format
+                    if "divide_by_shards" in transforms:
+                        tensor /= params["cores_per_replica"]
+                    if "vocab_pad" in transforms:
+                        tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"]))
+                    if "no_transpose" not in transforms:
+                        tensor = tensor.T
+                    tensor.unsqueeze_(0)
+
+                    # Shard the tensor so that parts of the tensor can be used
+                    # on different TPU cores
+                    network.state["params"][spec["module"]][spec["param"]] = move_xmap(
+                        jnp.array(
+                            reshard_reverse(
+                                tensor,
+                                params["cores_per_replica"],
+                                network.state["params"][spec["module"]][spec["param"]].shape,
+                            ),
+                            dtype=jnp.bfloat16,
+                        ),
+                        np.empty(params["cores_per_replica"]),
+                    )
+
+                # Check for tensors that MTJ needs that were not provided in the
+                # HF model
+                for mk, mv in network.state["params"].items():
+                    for pk, pv in mv.items():
+                        if isinstance(pv, PlaceholderTensor):
+                            # The transformers GPT-J models apparently do not
+                            # have embedding bias, whereas MTJ GPT-J models do,
+                            # so we have to supplement an embedding bias tensor
+                            # by creating a tensor with the necessary shape, filled
+                            # with zeros.
+                            if mk == "causal_transformer_shard/~/embedding_shard/~/linear" and pk == "b":
+                                mv[pk] = move_xmap(jnp.zeros(mv[pk].shape, dtype=jnp.bfloat16), np.empty(params["cores_per_replica"]))
+
+                            else:
+                                error = f"{mk} {pk} could not be found in the model checkpoint"
+                                print("\n\nERROR: " + error, file=sys.stderr)
+                                raise RuntimeError(error)
+            finally:
+                if isinstance(f, zipfile.ZipExtFile):
+                    f.close()
+
+    if os.path.isdir(vars.model.replace('/', '_')):
+        import shutil
+        shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
+    with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
+        if(os.path.isdir(vars.custmodpth)):
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
+            except ValueError as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
+            try:
+                model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
+            except ValueError as e:
+                model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
+        elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
+            try:
+                tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+            except ValueError as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+            try:
+                model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
+            except ValueError as e:
+                model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
+        else:
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
+            except ValueError as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
+            try:
+                model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache")
+            except ValueError as e:
+                model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")
+
     network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
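
Illustration (not part of the patch above): a minimal, standalone sketch of the cores_per_replica / n_vocab_padding arithmetic that the hf_checkpoint branch of load_model() performs. The config values below are made up purely for illustration and are not taken from any real checkpoint.

n_heads, d_model, n_vocab = 16, 4096, 50257  # illustrative values only

# Largest even TPU core count (falling back to 1) that divides both
# the number of attention heads and the embedding dimension.
for c in (8, 6, 4, 2, 1):
    if n_heads % c == 0 == d_model % c:
        cores_per_replica = c
        break

# -(n % -k) is the smallest padding that rounds n up to a multiple of k,
# i.e. the minimum number of dummy tokens added to the vocabulary.
n_vocab_padding = -(n_vocab % -cores_per_replica)
assert (n_vocab + n_vocab_padding) % cores_per_replica == 0

print(cores_per_replica, n_vocab_padding)  # -> 8 7 for the values above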