GPT-NeoX-20B support in Colab TPU instances
parent 4892556059
commit 88f247d535

aiserver.py: 34 changed lines

@@ -317,7 +317,7 @@ def getmodelname():
     if(args.configname):
         modelname = args.configname
         return modelname
-    if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ")):
+    if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
         modelname = os.path.basename(os.path.normpath(vars.custmodpth))
         return modelname
     else:
@@ -699,7 +699,7 @@ def spRequest(filename):
     vars.sp_length = tensor.shape[-2]
     vars.spmeta["n_tokens"] = vars.sp_length

-    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
+    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
         rows = tensor.shape[0]
         padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
         tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
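Note (illustrative sketch, not part of the diff): the padding expression above rounds the soft prompt up to a multiple of cores_per_replica rows. A worked example using the NeoX defaults added later in this commit (seq=2048, cores_per_replica=8) and an assumed 20-row soft prompt:

    seq, cores_per_replica, rows = 2048, 8, 20   # assumed example values
    # In Python, seq % -cores_per_replica lies in (-cores_per_replica, 0], so
    # seq - (seq % -cores_per_replica) equals seq rounded up to a multiple of cores_per_replica.
    padding_amount = seq - (seq % -cores_per_replica) - rows   # 2048 - 0 - 20 = 2028
    padded_rows = rows + padding_amount                        # 2048, evenly divisible across 8 cores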
@@ -790,7 +790,7 @@ else:
     getModelSelection(mainmenu)

 # If transformers model was selected & GPU available, ask to use CPU or GPU
-if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
+if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
     vars.allowsp = True
     # Test for GPU support
     import torch
@@ -830,7 +830,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe
         print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
         vars.model_type = "gpt_neo"

-if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
+if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
     loadmodelsettings()
     loadsettings()
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -1032,7 +1032,7 @@ socketio = SocketIO(app, async_method="eventlet")
 print("{0}OK!{1}".format(colors.GREEN, colors.END))

 # Start transformers and create pipeline
-if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
+if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
         from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
@@ -1050,7 +1050,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         if not vars.lazy_load:
             return

-        from tqdm import tqdm
+        from tqdm.auto import tqdm

         if "breakmodel" in globals():
             gpu_blocks = breakmodel.gpu_blocks
@@ -1553,9 +1553,9 @@ else:
     tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
     loadsettings()
 # Load the TPU backend if requested
-elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
+elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
     print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
-    if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
+    if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
         raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
     import tpu_mtj_backend
     tpu_mtj_backend.vars = vars
@@ -1567,7 +1567,7 @@ else:
     vars.allowsp = True
     loadmodelsettings()
     loadsettings()
-    tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model != "TPUMeshTransformerGPTJ" and vars.use_colab_tpu, **vars.modelconfig)
+    tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
     vars.modeldim = int(tpu_mtj_backend.params["d_model"])
     tokenizer = tpu_mtj_backend.tokenizer
 else:
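Illustrative sketch (not part of the diff), assuming a hypothetical local checkpoint folder: roughly how aiserver.py drives the new backend path for the NeoX model type. vars.custmodpth is an assumed path; the printed value reflects the NeoX defaults added later in this commit.

    import tpu_mtj_backend

    vars.model = "TPUMeshTransformerGPTNeoX"
    vars.custmodpth = "/content/gpt-neox-20b"      # assumed path to a NeoX checkpoint folder
    tpu_mtj_backend.vars = vars
    # hf_checkpoint evaluates to False for this model type per the condition above
    tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=False, **vars.modelconfig)
    print(int(tpu_mtj_backend.params["d_model"]))  # 6144 with the NeoX default_params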
@@ -2098,7 +2098,7 @@ def lua_get_modeltype():
         return "readonly"
     if(vars.model in ("Colab", "OAI", "InferKit")):
         return "api"
-    if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
+    if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
         hidden_size = get_hidden_size_from_model(model)
     if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)):
         return "gpt2"
@@ -2127,7 +2127,7 @@ def lua_get_modelbackend():
         return "readonly"
     if(vars.model in ("Colab", "OAI", "InferKit")):
         return "api"
-    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
+    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
         return "mtj"
     return "transformers"

@@ -2136,7 +2136,7 @@ def lua_get_modelbackend():
 #==================================================================#
 @bridged_kwarg()
 def lua_is_custommodel():
-    return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
+    return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")

 #==================================================================#
 #
@@ -3074,22 +3074,22 @@ def calcsubmit(txt):
     if(vars.model != "InferKit"):
         subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
         if(actionlen == 0):
-            if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
+            if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
                 sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
             elif(vars.model == "OAI"):
                 oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
+            elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
                 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
         else:
-            if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
+            if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
                 sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
             elif(vars.model == "OAI"):
                 oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
+            elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
                 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)

     # For InferKit web API
@@ -5105,7 +5105,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
     file.close()

 # Precompile TPU backend if required
-if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
+if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
     soft_tokens = tpumtjgetsofttokens()
     if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
         threading.Thread(

tpu_mtj_backend.py

@@ -46,6 +46,7 @@ import numpy as np
 import optax
 import haiku as hk
 from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
+from tokenizers import Tokenizer
 from mesh_transformer.checkpoint import read_ckpt_lowmem
 from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
 from mesh_transformer.util import to_bf16
@@ -800,6 +801,121 @@ def reshard_reverse(x, total_shards, old_shape):
     return out


+def get_old_shape(t, total_shards, dim=2):
+    if len(t.shape) == 2:
+        shard_shape = t.shape
+        if dim == 1:
+            assert shard_shape[0] % total_shards == 0
+            return (shard_shape[0] // total_shards, shard_shape[1])
+        elif dim == 2:
+            assert shard_shape[1] % total_shards == 0
+            return (shard_shape[0], shard_shape[1] // total_shards)
+        else:
+            raise ValueError(f"Unsupported dim {dim}")
+    if len(t.shape) == 1:
+        assert t.shape[0] % total_shards == 0
+        return (t.shape[0] // total_shards,)
+    else:
+        raise ValueError(f"Unsupported shape {t.shape}")
+
+
+def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
+    assert config["cores_per_replica"] % checkpoint_shards == 0
+    output_shards = config["cores_per_replica"] // checkpoint_shards
+
+    import torch
+    from tqdm.auto import tqdm
+
+    move_xmap = jax.experimental.maps.xmap(
+        fun=lambda x, _: to_bf16(x),
+        in_axes=(["shard", ...], ["batch", ...]),
+        out_axes=["shard", ...],
+        axis_resources={'shard': 'mp', 'batch': 'dp'}
+    )
+
+    path_template = os.path.join(path, "layer_{layer:02d}-model_{shard:02d}-model_states.pt")
+
+    static_mapping = {
+        "word_embeddings.weight": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1},
+        "final_linear.weight": {"module": "projection_shard/~/linear", "param": "w", "axis": 2},
+        "norm.weight": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale", "axis": None},
+        "norm.bias": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset", "axis": None},
+    }
+
+    layer_mapping = {
+        "attention.query_key_value.weight": {"module": "combined_qkv", "param": "w", "axis": 2},
+        "attention.query_key_value.bias": {"module": "combined_qkv", "param": "b", "axis": 1},
+        "attention.dense.weight": {"module": "linear_3", "param": "w", "axis": 1},
+        "attention.dense.bias": {"module": "linear_3", "param": "b", "axis": None},
+        "mlp.dense_h_to_4h.weight": {"module": "linear_4", "param": "w", "axis": 2},
+        "mlp.dense_h_to_4h.bias": {"module": "linear_4", "param": "b", "axis": 1},
+        "mlp.dense_4h_to_h.weight": {"module": "linear_5", "param": "w", "axis": 1},
+        "mlp.dense_4h_to_h.bias": {"module": "linear_5", "param": "b", "axis": None},
+        "input_layernorm.weight": {"module": "replicated_layer_norm", "param": "scale", "axis": None},
+        "input_layernorm.bias": {"module": "replicated_layer_norm", "param": "offset", "axis": None},
+        "post_attention_layernorm.weight": {"module": "replicated_layer_norm_1", "param": "scale", "axis": None},
+        "post_attention_layernorm.bias": {"module": "replicated_layer_norm_1", "param": "offset", "axis": None},
+    }
+
+    tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping)
+    bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
+
+    for checkpoint_layer in range(config["layers"] + 5):
+        if checkpoint_layer in (1, config["layers"] + 2):
+            continue
+        layer = checkpoint_layer - 2
+        shards = []
+        for checkpoint_shard in range(checkpoint_shards):
+            shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
+        for key in shards[0]:
+            if key == "attention.rotary_emb.inv_freq":
+                continue
+            elif key in static_mapping:
+                target_module = "causal_transformer_shard/~/" + static_mapping[key]["module"]
+                target_param = static_mapping[key]["param"]
+                target_axis = static_mapping[key]["axis"]
+            elif key in layer_mapping:
+                target_module = f"causal_transformer_shard/~/layer_{layer}/~/" + layer_mapping[key]["module"]
+                target_param = layer_mapping[key]["param"]
+                target_axis = layer_mapping[key]["axis"]
+            else:
+                error = f"{repr(key)} not found in mapping"
+                print("\n\nERROR: ", error, file=sys.stderr)
+                raise RuntimeError(error)
+            original_shape = shards[0][key].shape
+            for checkpoint_shard in range(checkpoint_shards):
+                if key in ("attention.dense.bias", "mlp.dense_4h_to_h.bias"):
+                    shards[checkpoint_shard][key] /= output_shards
+                if key != "word_embeddings.weight":
+                    shards[checkpoint_shard][key] = shards[checkpoint_shard][key].T
+                tensor = shards[checkpoint_shard][key]
+                if target_axis is not None:
+                    target_shape = (output_shards,) + get_old_shape(tensor, total_shards=output_shards, dim=target_axis)
+                else:
+                    target_shape = (output_shards, tensor.shape[0])
+                shards[checkpoint_shard][key] = reshard_reverse(tensor.unsqueeze_(0), output_shards, target_shape)
+            #print(key, ":", original_shape, "->", shards[0][key].shape)
+            tensor = torch.cat([shards[s][key] for s in range(checkpoint_shards)], dim=0)
+            target_shape = state["params"][target_module][target_param].shape
+            if tensor.shape != target_shape:
+                error = f"Weight {repr(key)} has shape {tensor.shape} in checkpoint but shape {target_shape} was requested by MTJ for {target_module} {target_param}"
+                print("\n\nERROR: ", error, file=sys.stderr)
+                raise RuntimeError(error)
+            if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
+                tensor = tensor.bfloat16()
+            state["params"][target_module][target_param] = move_xmap(
+                jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)).copy(),
+                np.zeros(config["cores_per_replica"]),
+            )
+            bar.update(1)
+    for mk, mv in state["params"].items():
+        for pk, pv in mv.items():
+            if isinstance(pv, PlaceholderTensor):
+                error = f"{mk} {pk} could not be found in the model checkpoint"
+                print("\n\nERROR: " + error, file=sys.stderr)
+                raise RuntimeError(error)
+
+
 def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params

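Sketch (illustrative, not part of the diff): the file names read_neox_checkpoint() expects under path, derived from its path_template and skip logic, for the default 44-layer, 2-shard layout used by this commit.

    layers, checkpoint_shards = 44, 2   # values from the NeoX defaults added below
    expected_files = [
        f"layer_{layer:02d}-model_{shard:02d}-model_states.pt"
        for layer in range(layers + 5)
        if layer not in (1, layers + 2)  # these checkpoint layers are skipped in the loader
        for shard in range(checkpoint_shards)
    ]
    # e.g. layer_00-model_00-model_states.pt ... layer_48-model_01-model_states.pt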
@@ -820,6 +936,23 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     }
     params = kwargs

+    if vars.model == "TPUMeshTransformerGPTNeoX":
+        default_params = {
+            "compat": "neox",
+            "layers": 44,
+            "d_model": 6144,
+            "n_heads": 64,
+            "n_vocab": 50432,
+            "n_vocab_padding": 0,
+            "norm": "doublelayernorm",
+            "pe": "rotary",
+            "pe_rotary_dims": 24,
+            "seq": 2048,
+            "cores_per_replica": 8,
+            "tokenizer_class": "GPT2TokenizerFast",
+            "tokenizer": "gpt2",
+        }
+
     # Try to convert HF config.json to MTJ config
     if hf_checkpoint:
         spec_path = os.path.join("maps", vars.model_type + ".json")
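For orientation (derived from the code above, not an additional change): with these defaults, the shard arithmetic in read_neox_checkpoint() works out as follows.

    cores_per_replica, checkpoint_shards = 8, 2
    assert cores_per_replica % checkpoint_shards == 0
    output_shards = cores_per_replica // checkpoint_shards
    # 4: each of the 2 checkpoint shards is resharded into 4 device shards,
    # giving 8 shards in total, one per TPU core in the replica.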
@@ -875,7 +1008,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             params[param] = default_params[param]

     # Load tokenizer
-    if not hf_checkpoint:
+    if vars.model == "TPUMeshTransformerGPTNeoX":
+        tokenizer = Tokenizer.from_file(os.path.join(path, "20B_tokenizer.json"))
+        def new_encode(old_encode):
+            def encode(s, *args, **kwargs):
+                return old_encode(s).ids
+            return encode
+        tokenizer.encode = new_encode(tokenizer.encode)
+    elif not hf_checkpoint:
         if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
             raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
         tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
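Usage sketch (assumed local tokenizer file; not part of the diff): the wrapper above makes the tokenizers-library tokenizer return plain token ids from encode(), which is what the rest of the backend expects from a transformers-style tokenizer.

    from tokenizers import Tokenizer

    tok = Tokenizer.from_file("20B_tokenizer.json")  # hypothetical path
    enc = tok.encode("Hello world")                  # tokenizers returns an Encoding object
    ids = enc.ids                                    # the wrapped encode() returns these ids directly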
@@ -917,9 +1057,13 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo

     network = PenalizingCausalTransformer(params, dematerialized=True)

-    if not hf_checkpoint:
+    if not hf_checkpoint and vars.model != "TPUMeshTransformerGPTNeoX":
         network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
-        network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+        #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+        return
+
+    if vars.model == "TPUMeshTransformerGPTNeoX":
+        read_neox_checkpoint(network.state, path, params)
         return

     # Convert from HF checkpoint
@@ -945,7 +1089,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo

     import torch_lazy_loader
     import torch
-    from tqdm import tqdm
+    from tqdm.auto import tqdm

     def callback(model_dict, f, **_):
         with zipfile.ZipFile(f, "r") as z:
@@ -1069,4 +1213,4 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         except Exception as e:
             model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")

-    network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+    #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))