Support for loading HF models on TPU with `--colab_tpu`

Gnome Ann 2022-03-05 12:33:33 -05:00
parent 86ac562b0c
commit 0a258a6282
5 changed files with 306 additions and 68 deletions


@@ -695,7 +695,7 @@ def spRequest(filename):
vars.sp_length = tensor.shape[-2]
vars.spmeta["n_tokens"] = vars.sp_length
if(vars.model in ("TPUMeshTransformerGPTJ",)):
if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
rows = tensor.shape[0]
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
@@ -730,6 +730,7 @@ parser.add_argument("--override_delete", action='store_true', help="Deleting sto
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")
parser.add_argument("--colab_tpu", action='store_true', help="If you're running KoboldAI in a Google Colab TPU instance, enable this to load Hugging Face models onto the TPU.")
parser.add_argument("--nobreakmodel", action='store_true', help="Disables Breakmodel support completely.")
parser.add_argument("--unblock", action='store_true', default=False, help="Unblocks the KoboldAI port to be accessible from other machines without optimizing for remote play (It is recommended to use --host instead)")
parser.add_argument("--quiet", action='store_true', default=False, help="If present will suppress any story related text from showing on the console")
@@ -783,7 +784,7 @@ else:
getModelSelection(mainmenu)
# If transformers model was selected & GPU available, ask to use CPU or GPU
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
vars.allowsp = True
# Test for GPU support
import torch
@@ -822,6 +823,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
elif(vars.model_type == "not_found"):
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
vars.model_type = "gpt_neo"
if(not args.colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
loadmodelsettings()
loadsettings()
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -1014,7 +1017,7 @@ socketio = SocketIO(app, async_method="eventlet")
print("{0}OK!{1}".format(colors.GREEN, colors.END))
# Start transformers and create pipeline
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(not args.colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
@@ -1523,9 +1526,9 @@ else:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
loadsettings()
# Load the TPU backend if requested
elif(vars.model == "TPUMeshTransformerGPTJ"):
elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
if not vars.custmodpth or not os.path.isdir(vars.custmodpth):
if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
import tpu_mtj_backend
tpu_mtj_backend.vars = vars
@@ -1537,7 +1540,7 @@ else:
vars.allowsp = True
loadmodelsettings()
loadsettings()
tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=args.colab_tpu, **vars.modelconfig)
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
tokenizer = tpu_mtj_backend.tokenizer
else:
@@ -2068,7 +2071,7 @@ def lua_get_modeltype():
return "readonly"
if(vars.model in ("Colab", "OAI", "InferKit")):
return "api"
if(vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
if(not args.colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
hidden_size = get_hidden_size_from_model(model)
if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)):
return "gpt2"
@@ -2084,7 +2087,7 @@ def lua_get_modeltype():
return "gpt-neo-1.3B"
if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model_type == "gpt_neo" and hidden_size == 2560)):
return "gpt-neo-2.7B"
if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)):
if(vars.model in ("EleutherAI/gpt-j-6B",) or ((args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ") and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)):
return "gpt-j-6B"
return "unknown"
@@ -2097,7 +2100,7 @@ def lua_get_modelbackend():
return "readonly"
if(vars.model in ("Colab", "OAI", "InferKit")):
return "api"
if(vars.model in ("TPUMeshTransformerGPTJ",)):
if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
return "mtj"
return "transformers"
@@ -3044,22 +3047,22 @@ def calcsubmit(txt):
if(vars.model != "InferKit"):
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
if(actionlen == 0):
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
if(not args.colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "OAI"):
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "TPUMeshTransformerGPTJ"):
elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
else:
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
if(not args.colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "OAI"):
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "TPUMeshTransformerGPTJ"):
elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
# For InferKit web API
@@ -5071,7 +5074,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
file.close()
# Precompile TPU backend if required
if(vars.model in ("TPUMeshTransformerGPTJ",)):
if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
soft_tokens = tpumtjgetsofttokens()
if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
threading.Thread(


@@ -1,25 +1,32 @@
{
"mtj_compat": "neo",
"mtj_pe": "fixed",
"mtj_config_map": {
"d_model": "hidden_size",
"n_heads": "num_heads",
"layers": "num_layers"
},
"static_weights": {
"transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
"transformer.wpe.weight": {"mtj": {"module": "embedding_shard/~/pos_embs", "param": "w", "axis": 2, "transforms": ["transpose"]}},
"transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
"transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}
"transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
"transformer.wpe.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose"]}},
"transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
"transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}
},
"layer_weights": {
"transformer.h.{layer}.attn.attention.bias": {},
"transformer.h.{layer}.attn.attention.masked_bias": {},
"transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
"transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
"transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
"transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
"transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
"transformer.h.{layer}.attn.attention.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
"transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
"transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
"transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
"transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
"transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
"transformer.h.{layer}.mlp.c_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
"transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}},
"transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}},
"transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}}
"transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
"transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}},
"transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}},
"transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}}
}
}
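
For readers unfamiliar with these maps: the loader added to tpu_mtj_backend.py later in this commit expands each "layer_weights" key once per layer and prefixes every module path with "causal_transformer_shard/~/". A minimal sketch, using a made-up two-layer model and the q_proj entry from the map above:

    # Hypothetical expansion of one layer_weights entry (the layer count is invented)
    layers = 2
    entry_key = "transformer.h.{layer}.attn.attention.q_proj.weight"
    entry_spec = {"module": "layer_{layer}/~/linear", "param": "w"}
    model_spec = {}
    for layer in range(layers):
        key = entry_key.format(layer=layer)
        model_spec[key] = dict(entry_spec)
        model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer)
    # model_spec["transformer.h.0.attn.attention.q_proj.weight"] is now
    # {"module": "causal_transformer_shard/~/layer_0/~/linear", "param": "w"}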


@@ -1,24 +1,32 @@
{
"mtj_compat": "j",
"mtj_pe": "rotary",
"mtj_config_map": {
"pe_rotary_dims": ["rotary_dim", 64],
"d_model": "n_embd",
"n_heads": "n_head",
"layers": "n_layer"
},
"static_weights": {
"transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
"transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
"transformer.wte.bias": {"mtj": {"module": "embedding_shard/~/linear", "param": "b"}},
"transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
"transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}},
"lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}},
"transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
"transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}},
"lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}},
"lm_head.bias": {"mtj": {"module": "projection_shard/~/linear", "param": "b"}}
},
"layer_weights": {
"transformer.h.{layer}.attn.bias": {},
"transformer.h.{layer}.attn.masked_bias": {},
"transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
"transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
"transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
"transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
"transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
"transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
"transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
"transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
"transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
"transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
"transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
"transformer.h.{layer}.mlp.fc_out.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
"transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}
"transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
"transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}}
}
}


@@ -1,26 +1,33 @@
{
"mtj_compat": "fairseq_lm",
"mtj_pe": "fairseq_sinusoidal",
"mtj_config_map": {
"d_model": "d_model",
"n_heads": "attention_heads",
"layers": "num_layers"
},
"static_weights": {
"model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
"model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
"model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}},
"lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}}
"model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
"model.replicated_layer_norm.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
"model.replicated_layer_norm.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}},
"lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}}
},
"layer_weights": {
"model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
"model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b", "axis": 1}},
"model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
"model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b", "axis": 1}},
"model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
"model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b", "axis": 1}},
"model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
"model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
"model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b"}},
"model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
"model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b"}},
"model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
"model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b"}},
"model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
"model.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
"model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
"model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
"model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
"model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
"model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
"model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
"model.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
"model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}},
"model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}},
"model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}}
"model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
"model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}},
"model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}},
"model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}}
}
}


@@ -32,6 +32,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
import progressbar
import time
import os
import sys
import json
import zipfile
import requests
import random
import jax
@@ -41,9 +44,10 @@ import jax.numpy as jnp
import numpy as np
import optax
import haiku as hk
import transformers
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
from mesh_transformer.util import to_bf16
params: Dict[str, Any] = {}
@@ -776,7 +780,26 @@ def infer_static(
return samples
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
def reshard_reverse(x, total_shards, old_shape):
assert len(x.shape) != 1
if len(x.shape) == 2:
if old_shape[1] == x.shape[1]:
out = x[0:1].tile((total_shards, 1))
else:
out = x.reshape(old_shape)
elif len(x.shape) == 3:
if x.shape[0] * x.shape[2] == old_shape[2]:
out = x.reshape(old_shape)
elif x.shape[0] * x.shape[1] == old_shape[1]:
out = x.reshape((old_shape[1], old_shape[0], old_shape[2])).permute((1, 0, 2))
else:
assert False
else:
assert False
return out
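
A quick illustration of what reshard_reverse computes, using toy shapes invented for this sketch rather than taken from a real checkpoint: the helper receives a tensor with a leading singleton dimension plus the target MTJ parameter shape, and returns the tensor rearranged so that its first dimension is the shard axis.

    import torch

    shards = 4
    # A bias replicated across shards: (1, 12) is tiled to (4, 12)
    bias = torch.zeros(1, 12)
    assert tuple(reshard_reverse(bias, shards, (shards, 12)).shape) == (4, 12)
    # A weight whose columns are split across shards: (1, 8, 12) becomes (4, 8, 3)
    weight = torch.arange(96.0).reshape(1, 8, 12)
    assert tuple(reshard_reverse(weight, shards, (4, 8, 3)).shape) == (4, 8, 3)
    # A weight whose rows are split across shards: (1, 8, 12) becomes (4, 2, 12)
    assert tuple(reshard_reverse(weight, shards, (4, 2, 12)).shape) == (4, 2, 12)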
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params
default_params = {
@@ -795,6 +818,53 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
"tokenizer": "gpt2",
}
params = kwargs
# Try to convert HF config.json to MTJ config
if hf_checkpoint:
spec_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "maps", vars.model_type + ".json")
if not os.path.isfile(spec_path):
raise NotImplementedError(f"Unsupported model type {repr(vars.model_type)}")
with open(spec_path) as f:
lazy_load_spec = json.load(f)
if "mtj_compat" in lazy_load_spec:
params["compat"] = lazy_load_spec["mtj_compat"]
if "mtj_pe" in lazy_load_spec:
params["pe"] = lazy_load_spec["mtj_pe"]
for k, v in lazy_load_spec.get("mtj_config_map", {}).items():
if type(v) is not list:
params[k] = params[v]
continue
for i in range(len(v)):
if i == len(v) - 1:
params[k] = v[i]
elif v[i] in params:
params[k] = params[v[i]]
break
params["n_vocab"] = params["vocab_size"]
if "activation_function" in params:
params["activation"] = params["activation_function"]
# Both the number of attention heads in the model and the embedding
# dimension of the model need to be divisible by the number of TPU cores
# that we use, and JAX also requires the number of TPU cores used to be
# an even number if we're using more than one core, so logically we try
# to pick the largest possible even number of TPU cores such that the
# number of attention heads and embedding dimension are both divisible
# by the number of TPU cores, and fall back to one core if an even
# number of TPU cores is not possible.
for c in (8, 6, 4, 2, 1):
if 0 == params["n_heads"] % c == params["d_model"] % c:
params["cores_per_replica"] = c
break
# The vocabulary size of the model also has to be divisible by the
# number of TPU cores, so we pad the vocabulary with the minimum
# possible number of dummy tokens such that it's divisible.
params["n_vocab_padding"] = -(params["n_vocab"] % -params["cores_per_replica"])
if "compat" in params:
default_params["compat"] = params["compat"]
if default_params["compat"] == "fairseq_lm":
@@ -804,10 +874,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
params[param] = default_params[param]
# Load tokenizer
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
if not hf_checkpoint:
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
# Disable JAX warnings about these two functions having been renamed
jax.host_count = jax.process_count
@@ -844,5 +915,147 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
path += "/"
network = PenalizingCausalTransformer(params, dematerialized=True)
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
if not hf_checkpoint:
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
return
# Convert from HF checkpoint
move_xmap = jax.experimental.maps.xmap(
fun=lambda x, _: to_bf16(x),
in_axes=(["shard", ...], ["batch", ...]),
out_axes=["shard", ...],
axis_resources={'shard': 'mp', 'batch': 'dp'}
)
model_spec = {}
for key, spec in lazy_load_spec.get("static_weights", {}).items():
if spec.get("mtj") is not None:
model_spec[key] = spec["mtj"].copy()
model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"]
for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
for layer in range(params["layers"]):
if spec.get("mtj") is not None:
key = _key.format(layer=layer)
model_spec[key] = spec["mtj"].copy()
model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer)
import torch_lazy_loader
import torch
from tqdm import tqdm
def callback(model_dict, f, **_):
with zipfile.ZipFile(f, "r") as z:
try:
last_storage_key = None
f = None
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
# Some model weights are used by transformers but not by MTJ.
# We have to materialize these weights anyways because
# transformers will throw a tantrum otherwise. To attain
# the least possible memory usage, we create them as meta
# tensors, which don't take up any actual CPU or TPU memory.
if key not in model_spec:
model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].storage_type(0).dtype, device="meta")
continue
storage_key = model_dict[key].key
if storage_key != last_storage_key:
last_storage_key = storage_key
if isinstance(f, zipfile.ZipExtFile):
f.close()
f = z.open(f"archive/data/{storage_key}")
current_offset = f.tell()
if current_offset != model_dict[key].seek_offset:
f.seek(model_dict[key].seek_offset - current_offset, 1)
spec = model_spec[key]
transforms = set(spec.get("transforms", ()))
if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
error = f"Duplicate key {repr(key)}"
print("\n\nERROR: " + error, file=sys.stderr)
raise RuntimeError(error)
tensor = model_dict[key].materialize(f, map_location="cpu")
model_dict[key] = tensor.to("meta")
# MTJ requires certain mathematical operations to be performed
# on tensors in order for them to be in the correct format
if "divide_by_shards" in transforms:
tensor /= params["cores_per_replica"]
if "vocab_pad" in transforms:
tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"]))
if "no_transpose" not in transforms:
tensor = tensor.T
tensor.unsqueeze_(0)
# Shard the tensor so that parts of the tensor can be used
# on different TPU cores
network.state["params"][spec["module"]][spec["param"]] = move_xmap(
jnp.array(
reshard_reverse(
tensor,
params["cores_per_replica"],
network.state["params"][spec["module"]][spec["param"]].shape,
),
dtype=jnp.bfloat16,
),
np.empty(params["cores_per_replica"]),
)
# Check for tensors that MTJ needs that were not provided in the
# HF model
for mk, mv in network.state["params"].items():
for pk, pv in mv.items():
if isinstance(pv, PlaceholderTensor):
# The transformers GPT-J models apparently do not
# have embedding bias, whereas MTJ GPT-J models do,
# so we have to supplement an embedding bias tensor
# by creating a tensor with the necessary shape, filled
# with zeros.
if mk == "causal_transformer_shard/~/embedding_shard/~/linear" and pk == "b":
mv[pk] = move_xmap(jnp.zeros(mv[pk].shape, dtype=jnp.bfloat16), np.empty(params["cores_per_replica"]))
else:
error = f"{mk} {pk} could not be found in the model checkpoint"
print("\n\nERROR: " + error, file=sys.stderr)
raise RuntimeError(error)
finally:
if isinstance(f, zipfile.ZipExtFile):
f.close()
if os.path.isdir(vars.model.replace('/', '_')):
import shutil
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
if(os.path.isdir(vars.custmodpth)):
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache")
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
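
To make the callback's per-tensor path concrete, here is a minimal standalone sketch of the same transform sequence applied to one stand-in projection weight; the shapes, shard count, and padding below are assumptions for illustration, while the real values come from the checkpoint and from params.

    import torch

    cores_per_replica = 4       # assumed shard count
    n_vocab_padding = 2         # assumed padding, as computed earlier in load_model
    transforms = {"vocab_pad"}  # transforms listed for this tensor in the map
    # Toy stand-in for a materialized lm_head-style weight of shape (vocab, d_model)
    tensor = torch.zeros(10, 8)
    if "divide_by_shards" in transforms:
        tensor /= cores_per_replica
    if "vocab_pad" in transforms:
        # pad the vocabulary dimension so it divides evenly across the TPU cores
        tensor = torch.nn.functional.pad(tensor, (0, 0, 0, n_vocab_padding))
    if "no_transpose" not in transforms:
        tensor = tensor.T
    tensor.unsqueeze_(0)
    # tensor.shape is now (1, 8, 12): the padded vocabulary (10 + 2) divides evenly
    # across the 4 cores, and reshard_reverse() then splits the tensor along the
    # shard axis before it is converted to a bfloat16 JAX array.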