Merge pull request #101 from VE-FORBRYDERNE/neox

GPT-NeoX-20B support in Colab TPU instances
This commit is contained in:
henk717 2022-03-19 09:56:15 +01:00 committed by GitHub
commit a7f652f293
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 174 additions and 22 deletions

View File

@ -196,6 +196,7 @@ class vars:
corescript = "default.lua" # Filename of corescript to load
# badwords = [] # Array of str/chr values that should be removed from output
badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting
badwordsids_neox = [[0], [1], [44162], [9502], [12520], [31841], [36320], [49824], [34417], [6038], [34494], [24815], [26635], [24345], [3455], [28905], [44270], [17278], [32666], [46880], [7086], [43189], [37322], [17778], [20879], [49821], [3138], [14490], [4681], [21391], [26786], [43134], [9336], [683], [48074], [41256], [19181], [29650], [28532], [36487], [45114], [46275], [16445], [15104], [11337], [1168], [5647], [29], [27482], [44965], [43782], [31011], [42944], [47389], [6334], [17548], [38329], [32044], [35487], [2239], [34761], [7444], [1084], [12399], [18990], [17636], [39083], [1184], [35830], [28365], [16731], [43467], [47744], [1138], [16079], [40116], [45564], [18297], [42368], [5456], [18022], [42696], [34476], [23505], [23741], [39334], [37944], [45382], [38709], [33440], [26077], [43600], [34418], [36033], [6660], [48167], [48471], [15775], [19884], [41533], [1008], [31053], [36692], [46576], [20095], [20629], [31759], [46410], [41000], [13488], [30952], [39258], [16160], [27655], [22367], [42767], [43736], [49694], [13811], [12004], [46768], [6257], [37471], [5264], [44153], [33805], [20977], [21083], [25416], [14277], [31096], [42041], [18331], [33376], [22372], [46294], [28379], [38475], [1656], [5204], [27075], [50001], [16616], [11396], [7748], [48744], [35402], [28120], [41512], [4207], [43144], [14767], [15640], [16595], [41305], [44479], [38958], [18474], [22734], [30522], [46267], [60], [13976], [31830], [48701], [39822], [9014], [21966], [31422], [28052], [34607], [2479], [3851], [32214], [44082], [45507], [3001], [34368], [34758], [13380], [38363], [4299], [46802], [30996], [12630], [49236], [7082], [8795], [5218], [44740], [9686], [9983], [45301], [27114], [40125], [1570], [26997], [544], [5290], [49193], [23781], [14193], [40000], [2947], [43781], [9102], [48064], [42274], [18772], [49384], [9884], [45635], [43521], [31258], [32056], [47686], [21760], [13143], [10148], [26119], [44308], [31379], [36399], [23983], [46694], [36134], [8562], [12977], [35117], [28591], [49021], [47093], [28653], [29013], [46468], [8605], [7254], [25896], [5032], [8168], [36893], [38270], [20499], [27501], [34419], [29547], [28571], [36586], [20871], [30537], [26842], [21375], [31148], [27618], [33094], [3291], [31789], [28391], [870], [9793], [41361], [47916], [27468], [43856], [8850], [35237], [15707], [47552], [2730], [41449], [45488], [3073], [49806], [21938], [24430], [22747], [20924], [46145], [20481], [20197], [8239], [28231], [17987], [42804], [47269], [29972], [49884], [21382], [46295], [36676], [34616], [3921], [26991], [27720], [46265], [654], [9855], [40354], [5291], [34904], [44342], [2470], [14598], [880], [19282], [2498], [24237], [21431], [16369], [8994], [44524], [45662], [13663], [37077], [1447], [37786], [30863], [42854], [1019], [20322], [4398], [12159], [44072], [48664], [31547], [18736], [9259], [31], [16354], [21810], [4357], [37982], [5064], [2033], [32871], [47446], [62], [22158], [37387], [8743], [47007], [17981], [11049], [4622], [37916], [36786], [35138], [29925], [14157], [18095], [27829], [1181], [22226], [5709], [4725], [30189], [37014], [1254], [11380], [42989], [696], [24576], [39487], [30119], [1092], [8088], [2194], [9899], [14412], [21828], [3725], [13544], [5180], [44679], [34398], [3891], [28739], [14219], [37594], [49550], [11326], [6904], [17266], [5749], [10174], [23405], [9955], [38271], [41018], [13011], [48392], [36784], [24254], [21687], [23734], [5413], [41447], [45472], [10122], [17555], [15830], [47384], [12084], [31350], [47940], [11661], [27988], [45443], [905], [49651], [16614], [34993], [6781], [30803], [35869], [8001], [41604], [28118], [46462], [46762], [16262], [17281], [5774], [10943], [5013], [18257], [6750], [4713], [3951], [11899], [38791], [16943], [37596], [9318], [18413], [40473], [13208], [16375]]
deletewi = None # Temporary storage for UID to delete
wirmvwhtsp = False # Whether to remove leading whitespace from WI entries
widepth = 3 # How many historical actions to scan for WI hits
@ -317,7 +318,7 @@ def getmodelname():
if(args.configname):
modelname = args.configname
return modelname
if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ")):
if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
modelname = os.path.basename(os.path.normpath(vars.custmodpth))
return modelname
else:
@ -699,7 +700,7 @@ def spRequest(filename):
vars.sp_length = tensor.shape[-2]
vars.spmeta["n_tokens"] = vars.sp_length
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
rows = tensor.shape[0]
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
@ -790,7 +791,7 @@ else:
getModelSelection(mainmenu)
# If transformers model was selected & GPU available, ask to use CPU or GPU
if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
vars.allowsp = True
# Test for GPU support
import torch
@ -830,7 +831,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
vars.model_type = "gpt_neo"
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
loadmodelsettings()
loadsettings()
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@ -1032,7 +1033,7 @@ socketio = SocketIO(app, async_method="eventlet")
print("{0}OK!{1}".format(colors.GREEN, colors.END))
# Start transformers and create pipeline
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
@ -1050,7 +1051,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
if not vars.lazy_load:
return
from tqdm import tqdm
from tqdm.auto import tqdm
if "breakmodel" in globals():
gpu_blocks = breakmodel.gpu_blocks
@ -1380,6 +1381,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
if os.path.isdir(vars.model.replace('/', '_')):
import shutil
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
print("\n", flush=True)
with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer) if vars.lazy_load else None, dematerialized_modules=True):
if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
lowmem = {}
@ -1553,11 +1555,15 @@ else:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
loadsettings()
# Load the TPU backend if requested
elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
if(vars.model == "TPUMeshTransformerGPTNeoX"):
vars.badwordsids = vars.badwordsids_neox
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
import tpu_mtj_backend
if(vars.model == "TPUMeshTransformerGPTNeoX"):
tpu_mtj_backend.pad_token_id = 1
tpu_mtj_backend.vars = vars
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
@ -1567,7 +1573,7 @@ else:
vars.allowsp = True
loadmodelsettings()
loadsettings()
tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model != "TPUMeshTransformerGPTJ" and vars.use_colab_tpu, **vars.modelconfig)
tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
tokenizer = tpu_mtj_backend.tokenizer
else:
@ -2098,7 +2104,7 @@ def lua_get_modeltype():
return "readonly"
if(vars.model in ("Colab", "OAI", "InferKit")):
return "api"
if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
hidden_size = get_hidden_size_from_model(model)
if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)):
return "gpt2"
@ -2127,7 +2133,7 @@ def lua_get_modelbackend():
return "readonly"
if(vars.model in ("Colab", "OAI", "InferKit")):
return "api"
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
return "mtj"
return "transformers"
@ -2136,7 +2142,7 @@ def lua_get_modelbackend():
#==================================================================#
@bridged_kwarg()
def lua_is_custommodel():
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
#==================================================================#
#
@ -3074,22 +3080,22 @@ def calcsubmit(txt):
if(vars.model != "InferKit"):
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
if(actionlen == 0):
if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "OAI"):
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
else:
if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "OAI"):
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
# For InferKit web API
@ -5105,7 +5111,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
file.close()
# Precompile TPU backend if required
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
soft_tokens = tpumtjgetsofttokens()
if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
threading.Thread(

View File

@ -46,6 +46,7 @@ import numpy as np
import optax
import haiku as hk
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
from tokenizers import Tokenizer
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
from mesh_transformer.util import to_bf16
@ -800,6 +801,121 @@ def reshard_reverse(x, total_shards, old_shape):
return out
def get_old_shape(t, total_shards, dim=2):
if len(t.shape) == 2:
shard_shape = t.shape
if dim == 1:
assert shard_shape[0] % total_shards == 0
return (shard_shape[0] // total_shards, shard_shape[1])
elif dim == 2:
assert shard_shape[1] % total_shards == 0
return (shard_shape[0], shard_shape[1] // total_shards)
else:
raise ValueError(f"Unsupported dim {dim}")
if len(t.shape) == 1:
assert t.shape[0] % total_shards == 0
return (t.shape[0] // total_shards,)
else:
raise ValueError(f"Unsupported shape {t.shape}")
def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
assert config["cores_per_replica"] % checkpoint_shards == 0
output_shards = config["cores_per_replica"] // checkpoint_shards
import torch
from tqdm.auto import tqdm
move_xmap = jax.experimental.maps.xmap(
fun=lambda x, _: to_bf16(x),
in_axes=(["shard", ...], ["batch", ...]),
out_axes=["shard", ...],
axis_resources={'shard': 'mp', 'batch': 'dp'}
)
path_template = os.path.join(path, "layer_{layer:02d}-model_{shard:02d}-model_states.pt")
static_mapping = {
"word_embeddings.weight": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1},
"final_linear.weight": {"module": "projection_shard/~/linear", "param": "w", "axis": 2},
"norm.weight": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale", "axis": None},
"norm.bias": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset", "axis": None},
}
layer_mapping = {
"attention.query_key_value.weight": {"module": "combined_qkv", "param": "w", "axis": 2},
"attention.query_key_value.bias": {"module": "combined_qkv", "param": "b", "axis": 1},
"attention.dense.weight": {"module": "linear_3", "param": "w", "axis": 1},
"attention.dense.bias": {"module": "linear_3", "param": "b", "axis": None},
"mlp.dense_h_to_4h.weight": {"module": "linear_4", "param": "w", "axis": 2},
"mlp.dense_h_to_4h.bias": {"module": "linear_4", "param": "b", "axis": 1},
"mlp.dense_4h_to_h.weight": {"module": "linear_5", "param": "w", "axis": 1},
"mlp.dense_4h_to_h.bias": {"module": "linear_5", "param": "b", "axis": None},
"input_layernorm.weight": {"module": "replicated_layer_norm", "param": "scale", "axis": None},
"input_layernorm.bias": {"module": "replicated_layer_norm", "param": "offset", "axis": None},
"post_attention_layernorm.weight": {"module": "replicated_layer_norm_1", "param": "scale", "axis": None},
"post_attention_layernorm.bias": {"module": "replicated_layer_norm_1", "param": "offset", "axis": None},
}
tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping)
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
for checkpoint_layer in range(config["layers"] + 5):
if checkpoint_layer in (1, config["layers"] + 2):
continue
layer = checkpoint_layer - 2
shards = []
for checkpoint_shard in range(checkpoint_shards):
shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
for key in shards[0]:
if key == "attention.rotary_emb.inv_freq":
continue
elif key in static_mapping:
target_module = "causal_transformer_shard/~/" + static_mapping[key]["module"]
target_param = static_mapping[key]["param"]
target_axis = static_mapping[key]["axis"]
elif key in layer_mapping:
target_module = f"causal_transformer_shard/~/layer_{layer}/~/" + layer_mapping[key]["module"]
target_param = layer_mapping[key]["param"]
target_axis = layer_mapping[key]["axis"]
else:
error = f"{repr(key)} not found in mapping"
print("\n\nERROR: ", error, file=sys.stderr)
raise RuntimeError(error)
original_shape = shards[0][key].shape
for checkpoint_shard in range(checkpoint_shards):
if key in ("attention.dense.bias", "mlp.dense_4h_to_h.bias"):
shards[checkpoint_shard][key] /= config["cores_per_replica"]
if key != "word_embeddings.weight" and shards[checkpoint_shard][key].ndim == 2:
shards[checkpoint_shard][key] = shards[checkpoint_shard][key].T
tensor = shards[checkpoint_shard][key]
if target_axis is not None:
target_shape = (output_shards,) + get_old_shape(tensor, total_shards=output_shards, dim=target_axis)
else:
target_shape = (output_shards, tensor.shape[0])
shards[checkpoint_shard][key] = reshard_reverse(tensor.unsqueeze_(0), output_shards, target_shape)
#print(key, ":", original_shape, "->", shards[0][key].shape)
tensor = torch.cat([shards[s][key] for s in range(checkpoint_shards)], dim=0)
target_shape = state["params"][target_module][target_param].shape
if tensor.shape != target_shape:
error = f"Weight {repr(key)} has shape {tensor.shape} in checkpoint but shape {target_shape} was requested by MTJ for {target_module} {target_param}"
print("\n\nERROR: ", error, file=sys.stderr)
raise RuntimeError(error)
if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
tensor = tensor.bfloat16()
state["params"][target_module][target_param] = move_xmap(
jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)).copy(),
np.zeros(config["cores_per_replica"]),
)
bar.update(1)
for mk, mv in state["params"].items():
for pk, pv in mv.items():
if isinstance(pv, PlaceholderTensor):
error = f"{mk} {pk} could not be found in the model checkpoint"
print("\n\nERROR: " + error, file=sys.stderr)
raise RuntimeError(error)
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params
@ -820,6 +936,23 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
}
params = kwargs
if vars.model == "TPUMeshTransformerGPTNeoX":
default_params = {
"compat": "neox",
"layers": 44,
"d_model": 6144,
"n_heads": 64,
"n_vocab": 50432,
"n_vocab_padding": 0,
"norm": "doublelayernorm",
"pe": "neox_rotary",
"pe_rotary_dims": 24,
"seq": 2048,
"cores_per_replica": 8,
"tokenizer_class": "GPT2TokenizerFast",
"tokenizer": "gpt2",
}
# Try to convert HF config.json to MTJ config
if hf_checkpoint:
spec_path = os.path.join("maps", vars.model_type + ".json")
@ -875,7 +1008,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
params[param] = default_params[param]
# Load tokenizer
if not hf_checkpoint:
if vars.model == "TPUMeshTransformerGPTNeoX":
tokenizer = Tokenizer.from_file(os.path.join(path, "20B_tokenizer.json"))
def new_encode(old_encode):
def encode(s, *args, **kwargs):
return old_encode(s).ids
return encode
tokenizer.encode = new_encode(tokenizer.encode)
elif not hf_checkpoint:
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
@ -917,9 +1057,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
network = PenalizingCausalTransformer(params, dematerialized=True)
if not hf_checkpoint:
if not hf_checkpoint and vars.model != "TPUMeshTransformerGPTNeoX":
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
#network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
return
if vars.model == "TPUMeshTransformerGPTNeoX":
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
read_neox_checkpoint(network.state, path, params)
return
# Convert from HF checkpoint
@ -945,7 +1090,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
import torch_lazy_loader
import torch
from tqdm import tqdm
from tqdm.auto import tqdm
def callback(model_dict, f, **_):
with zipfile.ZipFile(f, "r") as z:
@ -1031,6 +1176,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if os.path.isdir(vars.model.replace('/', '_')):
import shutil
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
print("\n", flush=True)
with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
if(os.path.isdir(vars.custmodpth)):
try:
@ -1069,4 +1215,4 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
except Exception as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
#network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))