Support for loading HF models on TPU with `--colab_tpu`

Gnome Ann 2022-03-05 12:33:33 -05:00
parent 86ac562b0c
commit 0a258a6282
5 changed files with 306 additions and 68 deletions

aiserver.py

@@ -695,7 +695,7 @@ def spRequest(filename):
     vars.sp_length = tensor.shape[-2]
     vars.spmeta["n_tokens"] = vars.sp_length
-    if(vars.model in ("TPUMeshTransformerGPTJ",)):
+    if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
         rows = tensor.shape[0]
         padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
         tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
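
The padding arithmetic above uses Python's floored modulo with a negative divisor to round the sequence length up to a multiple of cores_per_replica before subtracting the soft prompt's row count. A quick illustration of the idiom (not part of the commit):

    # n % -m lies in (-m, 0], so n - (n % -m) is n rounded up to a multiple of m.
    def rounded_up(n: int, m: int) -> int:
        return n - (n % -m)

    assert rounded_up(2048, 8) == 2048
    assert rounded_up(2050, 8) == 2056
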
@@ -730,6 +730,7 @@ parser.add_argument("--override_delete", action='store_true', help="Deleting sto
 parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
 parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
 parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")
+parser.add_argument("--colab_tpu", action='store_true', help="If you're running KoboldAI in a Google Colab TPU instance, enable this to load Hugging Face models onto the TPU.")
 parser.add_argument("--nobreakmodel", action='store_true', help="Disables Breakmodel support completely.")
 parser.add_argument("--unblock", action='store_true', default=False, help="Unblocks the KoboldAI port to be accessible from other machines without optimizing for remote play (It is recommended to use --host instead)")
 parser.add_argument("--quiet", action='store_true', default=False, help="If present will suppress any story related text from showing on the console")
@@ -783,7 +784,7 @@ else:
     getModelSelection(mainmenu)

 # If transformers model was selected & GPU available, ask to use CPU or GPU
-if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
+if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     vars.allowsp = True
     # Test for GPU support
     import torch
@@ -822,6 +823,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     elif(vars.model_type == "not_found"):
         print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
         vars.model_type = "gpt_neo"
+
+if(not args.colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     loadmodelsettings()
     loadsettings()
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -1014,7 +1017,7 @@ socketio = SocketIO(app, async_method="eventlet")
 print("{0}OK!{1}".format(colors.GREEN, colors.END))

 # Start transformers and create pipeline
-if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
+if(not args.colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
         from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
@@ -1523,9 +1526,9 @@ else:
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
     loadsettings()
 # Load the TPU backend if requested
-elif(vars.model == "TPUMeshTransformerGPTJ"):
+elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
     print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
-    if not vars.custmodpth or not os.path.isdir(vars.custmodpth):
+    if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
         raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
     import tpu_mtj_backend
     tpu_mtj_backend.vars = vars
@@ -1537,7 +1540,7 @@ else:
     vars.allowsp = True
     loadmodelsettings()
     loadsettings()
-    tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
+    tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=args.colab_tpu, **vars.modelconfig)
     vars.modeldim = int(tpu_mtj_backend.params["d_model"])
     tokenizer = tpu_mtj_backend.tokenizer
 else:
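
Both checkpoint formats now funnel through the same entry point; the flag merely selects which loader runs inside load_model. A hedged sketch of the two call shapes (paths and model names are illustrative, not from the commit):

    # Pre-existing path: read an MTJ-format checkpoint from a folder.
    tpu_mtj_backend.load_model("models/mtj-gpt-j-6b", hf_checkpoint=False)
    # New path with --colab_tpu: convert a Hugging Face checkpoint on the fly.
    tpu_mtj_backend.load_model("EleutherAI/gpt-neo-2.7B", hf_checkpoint=True)
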
@@ -2068,7 +2071,7 @@ def lua_get_modeltype():
         return "readonly"
     if(vars.model in ("Colab", "OAI", "InferKit")):
         return "api"
-    if(vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
+    if(not args.colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
         hidden_size = get_hidden_size_from_model(model)
     if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)):
         return "gpt2"
@@ -2084,7 +2087,7 @@ def lua_get_modeltype():
         return "gpt-neo-1.3B"
     if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model_type == "gpt_neo" and hidden_size == 2560)):
         return "gpt-neo-2.7B"
-    if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)):
+    if(vars.model in ("EleutherAI/gpt-j-6B",) or ((args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ") and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)):
         return "gpt-j-6B"
     return "unknown"
@@ -2097,7 +2100,7 @@ def lua_get_modelbackend():
         return "readonly"
     if(vars.model in ("Colab", "OAI", "InferKit")):
         return "api"
-    if(vars.model in ("TPUMeshTransformerGPTJ",)):
+    if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
         return "mtj"
     return "transformers"
@@ -3044,22 +3047,22 @@ def calcsubmit(txt):
     if(vars.model != "InferKit"):
         subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
         if(actionlen == 0):
-            if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
+            if(not args.colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
                 sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
             elif(vars.model == "OAI"):
                 oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(vars.model == "TPUMeshTransformerGPTJ"):
+            elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
                 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
         else:
-            if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
+            if(not args.colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
                 sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
             elif(vars.model == "OAI"):
                 oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(vars.model == "TPUMeshTransformerGPTJ"):
+            elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
                 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)

 # For InferKit web API
@@ -5071,7 +5074,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
     file.close()

 # Precompile TPU backend if required
-if(vars.model in ("TPUMeshTransformerGPTJ",)):
+if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
     soft_tokens = tpumtjgetsofttokens()
     if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
         threading.Thread(

maps/gpt_neo.json

@@ -1,25 +1,32 @@
 {
+    "mtj_compat": "neo",
+    "mtj_pe": "fixed",
+    "mtj_config_map": {
+        "d_model": "hidden_size",
+        "n_heads": "num_heads",
+        "layers": "num_layers"
+    },
     "static_weights": {
-        "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
-        "transformer.wpe.weight": {"mtj": {"module": "embedding_shard/~/pos_embs", "param": "w", "axis": 2, "transforms": ["transpose"]}},
-        "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
-        "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}
+        "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
+        "transformer.wpe.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose"]}},
+        "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
+        "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}
     },
     "layer_weights": {
         "transformer.h.{layer}.attn.attention.bias": {},
         "transformer.h.{layer}.attn.attention.masked_bias": {},
-        "transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
+        "transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
+        "transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
+        "transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
+        "transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
         "transformer.h.{layer}.attn.attention.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
-        "transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
-        "transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
+        "transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
+        "transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
+        "transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
         "transformer.h.{layer}.mlp.c_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
-        "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
-        "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}},
-        "transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}},
-        "transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}}
+        "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
+        "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}},
+        "transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}},
+        "transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}}
     }
 }
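
Each entry in these maps ties a Hugging Face state-dict key to the MTJ module and parameter it populates, with {layer} acting as a per-layer placeholder; entries with no "mtj" value (like the attention bias buffers) are deliberately skipped. The loader added in this commit expands the placeholders roughly like this (condensed sketch, two layers for brevity):

    layer_weights = {"transformer.h.{layer}.ln_1.weight":
                     {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}}}
    model_spec = {}
    for _key, spec in layer_weights.items():
        for layer in range(2):  # params["layers"] in the real loader
            if spec.get("mtj") is not None:
                key = _key.format(layer=layer)
                model_spec[key] = spec["mtj"].copy()
                model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer)
    assert model_spec["transformer.h.1.ln_1.weight"]["module"] == "causal_transformer_shard/~/layer_1/~/replicated_layer_norm"
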

maps/gptj.json

@@ -1,24 +1,32 @@
 {
+    "mtj_compat": "j",
+    "mtj_pe": "rotary",
+    "mtj_config_map": {
+        "pe_rotary_dims": ["rotary_dim", 64],
+        "d_model": "n_embd",
+        "n_heads": "n_head",
+        "layers": "n_layer"
+    },
     "static_weights": {
-        "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
+        "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
         "transformer.wte.bias": {"mtj": {"module": "embedding_shard/~/linear", "param": "b"}},
-        "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
-        "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}},
-        "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}},
+        "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
+        "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}},
+        "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}},
         "lm_head.bias": {"mtj": {"module": "projection_shard/~/linear", "param": "b"}}
     },
     "layer_weights": {
         "transformer.h.{layer}.attn.bias": {},
         "transformer.h.{layer}.attn.masked_bias": {},
-        "transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
-        "transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
-        "transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
-        "transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
+        "transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
+        "transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
+        "transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
+        "transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
+        "transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
+        "transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
+        "transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
         "transformer.h.{layer}.mlp.fc_out.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
-        "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
-        "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}
+        "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
+        "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}}
     }
 }
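
In mtj_config_map, a string value names the HF config.json key to copy into the MTJ params, while a two-element list gives a key followed by a default; "pe_rotary_dims": ["rotary_dim", 64] reads rotary_dim when present and falls back to 64. A condensed sketch of the lookup loop that load_model applies (hypothetical config values):

    params = {"n_embd": 4096}  # an HF config with no "rotary_dim" entry
    config_map = {"pe_rotary_dims": ["rotary_dim", 64], "d_model": "n_embd"}
    for k, v in config_map.items():
        if type(v) is not list:
            params[k] = params[v]
            continue
        for i in range(len(v)):
            if i == len(v) - 1:
                params[k] = v[i]          # last element: default value
            elif v[i] in params:
                params[k] = params[v[i]]  # earlier elements: config keys
                break
    assert params["pe_rotary_dims"] == 64 and params["d_model"] == 4096
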

maps/fairseq_dense.json

@@ -1,26 +1,33 @@
 {
+    "mtj_compat": "fairseq_lm",
+    "mtj_pe": "fairseq_sinusoidal",
+    "mtj_config_map": {
+        "d_model": "d_model",
+        "n_heads": "attention_heads",
+        "layers": "num_layers"
+    },
     "static_weights": {
-        "model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
-        "model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
-        "model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}},
-        "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}}
+        "model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
+        "model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
+        "model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}},
+        "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}}
     },
     "layer_weights": {
-        "model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
-        "model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b", "axis": 1}},
-        "model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
-        "model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b", "axis": 1}},
-        "model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
-        "model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b", "axis": 1}},
-        "model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
+        "model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
+        "model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b"}},
+        "model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
+        "model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b"}},
+        "model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
+        "model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b"}},
+        "model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
         "model.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
-        "model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
-        "model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
-        "model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
+        "model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
+        "model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
+        "model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
         "model.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
-        "model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
-        "model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}},
-        "model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}},
-        "model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}}
+        "model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
+        "model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}},
+        "model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}},
+        "model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}}
     }
 }

tpu_mtj_backend.py

@@ -32,6 +32,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 import progressbar
 import time
 import os
+import sys
+import json
+import zipfile
 import requests
 import random
 import jax
@@ -41,9 +44,10 @@ import jax.numpy as jnp
 import numpy as np
 import optax
 import haiku as hk
-import transformers
+from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
 from mesh_transformer.checkpoint import read_ckpt_lowmem
-from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard
+from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
+from mesh_transformer.util import to_bf16

 params: Dict[str, Any] = {}
@@ -776,7 +780,26 @@ def infer_static(
     return samples

+def reshard_reverse(x, total_shards, old_shape):
+    assert len(x.shape) != 1
+    if len(x.shape) == 2:
+        if old_shape[1] == x.shape[1]:
+            out = x[0:1].tile((total_shards, 1))
+        else:
+            out = x.reshape(old_shape)
+    elif len(x.shape) == 3:
+        if x.shape[0] * x.shape[2] == old_shape[2]:
+            out = x.reshape(old_shape)
+        elif x.shape[0] * x.shape[1] == old_shape[1]:
+            out = x.reshape((old_shape[1], old_shape[0], old_shape[2])).permute((1, 0, 2))
+        else:
+            assert False
+    else:
+        assert False
+    return out
+
-def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
+def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params

     default_params = {
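
reshard_reverse does the opposite of MTJ's usual shard gathering: given a single unsharded tensor and the per-shard shape MTJ reports for that parameter, it reshapes (or, for replicated 2-D parameters, tiles) the tensor into that layout. A shape-level sketch (shard count and dimensions are illustrative):

    import torch

    total_shards = 8
    # A transposed HF weight arrives as (1, 4096, 4096) after .T and
    # unsqueeze_(0); MTJ stores it as 8 shards of (512, 4096), so the
    # three-dimensional branch reshapes it to (8, 512, 4096).
    w = torch.zeros(1, 4096, 4096)
    assert w.reshape((8, 512, 4096)).shape == (8, 512, 4096)

    # A bias of shape (1, 4096) whose last dimension already matches the
    # target (8, 4096) is replicated: one copy is tiled onto every shard.
    b = torch.zeros(1, 4096)
    assert b[0:1].tile((total_shards, 1)).shape == (8, 4096)
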
@@ -795,6 +818,53 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         "tokenizer": "gpt2",
     }
     params = kwargs

+    # Try to convert HF config.json to MTJ config
+    if hf_checkpoint:
+        spec_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "maps", vars.model_type + ".json")
+        if not os.path.isfile(spec_path):
+            raise NotImplementedError(f"Unsupported model type {repr(vars.model_type)}")
+        with open(spec_path) as f:
+            lazy_load_spec = json.load(f)
+
+        if "mtj_compat" in lazy_load_spec:
+            params["compat"] = lazy_load_spec["mtj_compat"]
+        if "mtj_pe" in lazy_load_spec:
+            params["pe"] = lazy_load_spec["mtj_pe"]
+        for k, v in lazy_load_spec.get("mtj_config_map", {}).items():
+            if type(v) is not list:
+                params[k] = params[v]
+                continue
+            for i in range(len(v)):
+                if i == len(v) - 1:
+                    params[k] = v[i]
+                elif v[i] in params:
+                    params[k] = params[v[i]]
+                    break
+
+        params["n_vocab"] = params["vocab_size"]
+        if "activation_function" in params:
+            params["activation"] = params["activation_function"]
+
+        # Both the number of attention heads in the model and the embedding
+        # dimension of the model need to be divisible by the number of TPU cores
+        # that we use, and JAX also requires the number of TPU cores used to be
+        # an even number if we're using more than one core, so logically we try
+        # to pick the largest possible even number of TPU cores such that the
+        # number of attention heads and embedding dimension are both divisible
+        # by the number of TPU cores, and fall back to one core if an even
+        # number of TPU cores is not possible.
+        for c in (8, 6, 4, 2, 1):
+            if 0 == params["n_heads"] % c == params["d_model"] % c:
+                params["cores_per_replica"] = c
+                break
+
+        # The vocabulary size of the model also has to be divisible by the
+        # number of TPU cores, so we pad the vocabulary with the minimum
+        # possible number of dummy tokens such that it's divisible.
+        params["n_vocab_padding"] = -(params["n_vocab"] % -params["cores_per_replica"])
+
     if "compat" in params:
         default_params["compat"] = params["compat"]
     if default_params["compat"] == "fairseq_lm":
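
Worked through with GPT-Neo-2.7B's published config (hidden_size 2560, num_heads 20, num_layers 32, vocab_size 50257), the conversion picks 4 TPU cores, since 20 is divisible by neither 8 nor 6, and pads the vocabulary by 3 tokens:

    params = {"hidden_size": 2560, "num_heads": 20, "num_layers": 32, "vocab_size": 50257}
    for k, v in {"d_model": "hidden_size", "n_heads": "num_heads", "layers": "num_layers"}.items():
        params[k] = params[v]
    params["n_vocab"] = params["vocab_size"]
    for c in (8, 6, 4, 2, 1):
        if 0 == params["n_heads"] % c == params["d_model"] % c:
            params["cores_per_replica"] = c
            break
    assert params["cores_per_replica"] == 4
    assert -(params["n_vocab"] % -params["cores_per_replica"]) == 3  # n_vocab_padding
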
@@ -804,10 +874,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
             params[param] = default_params[param]

     # Load tokenizer
-    if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
-        raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
-    tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
-    tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
+    if not hf_checkpoint:
+        if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
+            raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
+        tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
+        tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])

     # Disable JAX warnings about these two functions having been renamed
     jax.host_count = jax.process_count
@@ -844,5 +915,147 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         path += "/"

     network = PenalizingCausalTransformer(params, dematerialized=True)
-    network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
+
+    if not hf_checkpoint:
+        network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
+        network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+        return
+
+    # Convert from HF checkpoint
+
+    move_xmap = jax.experimental.maps.xmap(
+        fun=lambda x, _: to_bf16(x),
+        in_axes=(["shard", ...], ["batch", ...]),
+        out_axes=["shard", ...],
+        axis_resources={'shard': 'mp', 'batch': 'dp'}
+    )
+
+    model_spec = {}
+    for key, spec in lazy_load_spec.get("static_weights", {}).items():
+        if spec.get("mtj") is not None:
+            model_spec[key] = spec["mtj"].copy()
+            model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"]
+    for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
+        for layer in range(params["layers"]):
+            if spec.get("mtj") is not None:
+                key = _key.format(layer=layer)
+                model_spec[key] = spec["mtj"].copy()
+                model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer)
+
+    import torch_lazy_loader
+    import torch
+    from tqdm import tqdm
+
+    def callback(model_dict, f, **_):
+        with zipfile.ZipFile(f, "r") as z:
+            try:
+                last_storage_key = None
+                f = None
+                print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
+                for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+
+                    # Some model weights are used by transformers but not by MTJ.
+                    # We have to materialize these weights anyways because
+                    # transformers will throw a tantrum otherwise. To attain
+                    # the least possible memory usage, we create them as meta
+                    # tensors, which don't take up any actual CPU or TPU memory.
+                    if key not in model_spec:
+                        model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].storage_type(0).dtype, device="meta")
+                        continue
+
+                    storage_key = model_dict[key].key
+                    if storage_key != last_storage_key:
+                        last_storage_key = storage_key
+                        if isinstance(f, zipfile.ZipExtFile):
+                            f.close()
+                        f = z.open(f"archive/data/{storage_key}")
+                    current_offset = f.tell()
+                    if current_offset != model_dict[key].seek_offset:
+                        f.seek(model_dict[key].seek_offset - current_offset, 1)
+                    spec = model_spec[key]
+                    transforms = set(spec.get("transforms", ()))
+                    if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
+                        error = f"Duplicate key {repr(key)}"
+                        print("\n\nERROR: " + error, file=sys.stderr)
+                        raise RuntimeError(error)
+                    tensor = model_dict[key].materialize(f, map_location="cpu")
+                    model_dict[key] = tensor.to("meta")
+
+                    # MTJ requires certain mathematical operations to be performed
+                    # on tensors in order for them to be in the correct format
+                    if "divide_by_shards" in transforms:
+                        tensor /= params["cores_per_replica"]
+                    if "vocab_pad" in transforms:
+                        tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"]))
+                    if "no_transpose" not in transforms:
+                        tensor = tensor.T
+                    tensor.unsqueeze_(0)
+
+                    # Shard the tensor so that parts of the tensor can be used
+                    # on different TPU cores
+                    network.state["params"][spec["module"]][spec["param"]] = move_xmap(
+                        jnp.array(
+                            reshard_reverse(
+                                tensor,
+                                params["cores_per_replica"],
+                                network.state["params"][spec["module"]][spec["param"]].shape,
+                            ),
+                            dtype=jnp.bfloat16,
+                        ),
+                        np.empty(params["cores_per_replica"]),
+                    )
+
+                # Check for tensors that MTJ needs that were not provided in the
+                # HF model
+                for mk, mv in network.state["params"].items():
+                    for pk, pv in mv.items():
+                        if isinstance(pv, PlaceholderTensor):
+                            # The transformers GPT-J models apparently do not
+                            # have embedding bias, whereas MTJ GPT-J models do,
+                            # so we have to supplement an embedding bias tensor
+                            # by creating a tensor with the necessary shape, filled
+                            # with zeros.
+                            if mk == "causal_transformer_shard/~/embedding_shard/~/linear" and pk == "b":
+                                mv[pk] = move_xmap(jnp.zeros(mv[pk].shape, dtype=jnp.bfloat16), np.empty(params["cores_per_replica"]))
+                            else:
+                                error = f"{mk} {pk} could not be found in the model checkpoint"
+                                print("\n\nERROR: " + error, file=sys.stderr)
+                                raise RuntimeError(error)
+            finally:
+                if isinstance(f, zipfile.ZipExtFile):
+                    f.close()
+
+    if os.path.isdir(vars.model.replace('/', '_')):
+        import shutil
+        shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
+
+    with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
+        if(os.path.isdir(vars.custmodpth)):
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
+            except ValueError as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
+            try:
+                model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
+            except ValueError as e:
+                model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
+        elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
+            try:
+                tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+            except ValueError as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+            try:
+                model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
+            except ValueError as e:
+                model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
+        else:
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
+            except ValueError as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
+            try:
+                model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache")
+            except ValueError as e:
+                model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")
+
     network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
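
The three transforms used by the map files correspond to the tensor fix-ups in the callback above: divide_by_shards pre-divides a bias that every shard will add once before the partial results are summed, vocab_pad grows the vocabulary axis to the padded size computed earlier, and any tensor without no_transpose is flipped from Hugging Face's (out_features, in_features) layout into MTJ's orientation. A small PyTorch sketch of vocab_pad, reusing the 4-core/3-token numbers from the worked example above:

    import torch

    lm_head = torch.zeros(50257, 4096)    # HF lm_head.weight: (vocab, d_model)
    n_vocab_padding = -(50257 % -4)       # == 3
    padded = torch.nn.functional.pad(lm_head, (0, 0, 0, n_vocab_padding))
    assert padded.shape == (50260, 4096)  # 50260 rows: divisible by 4 cores
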