Merge pull request #101 from VE-FORBRYDERNE/neox

GPT-NeoX-20B support in Colab TPU instances
commit a7f652f293
henk717 committed on 2022-03-19 09:56:15 +01:00 (via GitHub)
2 changed files with 174 additions and 22 deletions

aiserver.py

@@ -196,6 +196,7 @@ class vars:
corescript = "default.lua" # Filename of corescript to load
# badwords = [] # Array of str/chr values that should be removed from output
badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting
badwordsids_neox = [[0], [1], [44162], [9502], [12520], [31841], [36320], [49824], [34417], [6038], [34494], [24815], [26635], [24345], [3455], [28905], [44270], [17278], [32666], [46880], [7086], [43189], [37322], [17778], [20879], [49821], [3138], [14490], [4681], [21391], [26786], [43134], [9336], [683], [48074], [41256], [19181], [29650], [28532], [36487], [45114], [46275], [16445], [15104], [11337], [1168], [5647], [29], [27482], [44965], [43782], [31011], [42944], [47389], [6334], [17548], [38329], [32044], [35487], [2239], [34761], [7444], [1084], [12399], [18990], [17636], [39083], [1184], [35830], [28365], [16731], [43467], [47744], [1138], [16079], [40116], [45564], [18297], [42368], [5456], [18022], [42696], [34476], [23505], [23741], [39334], [37944], [45382], [38709], [33440], [26077], [43600], [34418], [36033], [6660], [48167], [48471], [15775], [19884], [41533], [1008], [31053], [36692], [46576], [20095], [20629], [31759], [46410], [41000], [13488], [30952], [39258], [16160], [27655], [22367], [42767], [43736], [49694], [13811], [12004], [46768], [6257], [37471], [5264], [44153], [33805], [20977], [21083], [25416], [14277], [31096], [42041], [18331], [33376], [22372], [46294], [28379], [38475], [1656], [5204], [27075], [50001], [16616], [11396], [7748], [48744], [35402], [28120], [41512], [4207], [43144], [14767], [15640], [16595], [41305], [44479], [38958], [18474], [22734], [30522], [46267], [60], [13976], [31830], [48701], [39822], [9014], [21966], [31422], [28052], [34607], [2479], [3851], [32214], [44082], [45507], [3001], [34368], [34758], [13380], [38363], [4299], [46802], [30996], [12630], [49236], [7082], [8795], [5218], [44740], [9686], [9983], [45301], [27114], [40125], [1570], [26997], [544], [5290], [49193], [23781], [14193], [40000], [2947], [43781], [9102], [48064], [42274], [18772], [49384], [9884], [45635], [43521], [31258], [32056], [47686], [21760], [13143], [10148], [26119], [44308], [31379], [36399], [23983], [46694], [36134], [8562], [12977], [35117], [28591], [49021], [47093], [28653], [29013], [46468], [8605], [7254], [25896], [5032], [8168], [36893], [38270], [20499], [27501], [34419], [29547], [28571], [36586], [20871], [30537], [26842], [21375], [31148], [27618], [33094], [3291], [31789], [28391], [870], [9793], [41361], [47916], [27468], [43856], [8850], [35237], [15707], [47552], [2730], [41449], [45488], [3073], [49806], [21938], [24430], [22747], [20924], [46145], [20481], [20197], [8239], [28231], [17987], [42804], [47269], [29972], [49884], [21382], [46295], [36676], [34616], [3921], [26991], [27720], [46265], [654], [9855], [40354], [5291], [34904], [44342], [2470], [14598], [880], [19282], [2498], [24237], [21431], [16369], [8994], [44524], [45662], [13663], [37077], [1447], [37786], [30863], [42854], [1019], [20322], [4398], [12159], [44072], [48664], [31547], [18736], [9259], [31], [16354], [21810], [4357], [37982], [5064], [2033], [32871], [47446], [62], [22158], [37387], [8743], [47007], [17981], [11049], [4622], [37916], [36786], [35138], [29925], [14157], [18095], [27829], [1181], [22226], [5709], [4725], [30189], [37014], [1254], [11380], [42989], [696], [24576], [39487], [30119], [1092], [8088], [2194], [9899], [14412], [21828], [3725], [13544], [5180], [44679], [34398], [3891], [28739], [14219], [37594], [49550], [11326], [6904], [17266], [5749], [10174], [23405], [9955], [38271], [41018], [13011], [48392], [36784], [24254], [21687], [23734], [5413], [41447], [45472], [10122], [17555], [15830], [47384], [12084], [31350], [47940], [11661], [27988], [45443], [905], [49651], [16614], [34993], [6781], [30803], [35869], [8001], [41604], [28118], [46462], [46762], [16262], [17281], [5774], [10943], [5013], [18257], [6750], [4713], [3951], [11899], [38791], [16943], [37596], [9318], [18413], [40473], [13208], [16375]]
deletewi = None # Temporary storage for UID to delete
wirmvwhtsp = False # Whether to remove leading whitespace from WI entries
widepth = 3 # How many historical actions to scan for WI hits
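badwordsids_neox plays the same role as badwordsids above but holds GPT-NeoX-20B token ids, since the NeoX tokenizer shares no id space with the GPT-2/J tokenizer. A hedged sketch of how a comparable list could be regenerated (illustrative only, not necessarily how these ids were produced; the EleutherAI/gpt-neox-20b tokenizer id is an assumption):

from transformers import AutoTokenizer

# Ban the NeoX special tokens (ids 0 and 1) plus every vocabulary entry containing
# a square bracket, mirroring the apparent intent of the GPT-2/J list above.
tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
bad_ids = [[0], [1]] + sorted(
    [tid] for token, tid in tok.get_vocab().items() if "[" in token or "]" in token
)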
@@ -317,7 +318,7 @@ def getmodelname():
if(args.configname):
modelname = args.configname
return modelname
- if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ")):
+ if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
modelname = os.path.basename(os.path.normpath(vars.custmodpth))
return modelname
else:
@@ -699,7 +700,7 @@ def spRequest(filename):
vars.sp_length = tensor.shape[-2]
vars.spmeta["n_tokens"] = vars.sp_length
- if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
+ if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
rows = tensor.shape[0]
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
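The padding line above rounds the soft prompt up to the model's full sequence length, itself rounded up to a multiple of cores_per_replica; the modulo with a negative divisor is a round-up trick. Worked numbers (the 20-row soft prompt is hypothetical; seq and cores are the stock MTJ values):

seq, cores, rows = 2048, 8, 20
rounded_seq = seq - (seq % -cores)   # rounds seq up to a multiple of cores (2048 stays 2048)
padding_amount = rounded_seq - rows  # 2028 rows of zeros appended after the soft prompt
assert (rows + padding_amount) % cores == 0
assert 2050 - (2050 % -8) == 2056    # the same trick when seq is not already a multiple of 8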
@@ -790,7 +791,7 @@ else:
getModelSelection(mainmenu)
# If transformers model was selected & GPU available, ask to use CPU or GPU
- if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
+ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
vars.allowsp = True
# Test for GPU support
import torch
@@ -830,7 +831,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
vars.model_type = "gpt_neo"
- if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
+ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
loadmodelsettings()
loadsettings()
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -1032,7 +1033,7 @@ socketio = SocketIO(app, async_method="eventlet")
print("{0}OK!{1}".format(colors.GREEN, colors.END))
# Start transformers and create pipeline
- if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
+ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
@@ -1050,7 +1051,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
if not vars.lazy_load:
return
- from tqdm import tqdm
+ from tqdm.auto import tqdm
if "breakmodel" in globals():
gpu_blocks = breakmodel.gpu_blocks
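The switch to tqdm.auto matters for the Colab use case this PR targets: it renders the notebook widget progress bar inside Colab/Jupyter and falls back to the plain terminal bar elsewhere, whereas plain tqdm always prints the terminal bar. Minimal illustration:

from tqdm.auto import tqdm  # picks the notebook widget in Colab, terminal output otherwise

for _ in tqdm(range(3), desc="demo"):
    pass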
@@ -1380,6 +1381,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
if os.path.isdir(vars.model.replace('/', '_')):
import shutil
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
print("\n", flush=True)
with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer) if vars.lazy_load else None, dematerialized_modules=True):
if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
lowmem = {}
@@ -1553,11 +1555,15 @@ else:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
loadsettings()
# Load the TPU backend if requested
- elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
+ elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
if(vars.model == "TPUMeshTransformerGPTNeoX"):
vars.badwordsids = vars.badwordsids_neox
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
- if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
+ if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
import tpu_mtj_backend
if(vars.model == "TPUMeshTransformerGPTNeoX"):
tpu_mtj_backend.pad_token_id = 1
tpu_mtj_backend.vars = vars
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
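The pad token override reflects the NeoX-20B tokenizer layout, where id 0 is <|endoftext|> and id 1 is <|padding|> (which is also why badwordsids_neox begins with [[0], [1]]). A quick check, assuming a local copy of the 20B_tokenizer.json file that the backend loads:

from tokenizers import Tokenizer

tok = Tokenizer.from_file("20B_tokenizer.json")
print(tok.id_to_token(0), tok.id_to_token(1))  # expected: <|endoftext|> <|padding|>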
@@ -1567,7 +1573,7 @@ else:
vars.allowsp = True
loadmodelsettings()
loadsettings()
- tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model != "TPUMeshTransformerGPTJ" and vars.use_colab_tpu, **vars.modelconfig)
+ tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
tokenizer = tpu_mtj_backend.tokenizer
else:
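With this change, load_model receives hf_checkpoint=True only for regular Hugging Face models running on the Colab TPU path, never for the two MTJ-native model types. A minimal sketch of the condition (the helper name and model ids are illustrative):

def wants_hf_checkpoint(model: str, use_colab_tpu: bool) -> bool:
    return model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and use_colab_tpu

assert wants_hf_checkpoint("KoboldAI/GPT-Neo-2.7B-Horni", True)    # HF model on a Colab TPU
assert not wants_hf_checkpoint("TPUMeshTransformerGPTNeoX", True)  # MTJ-native NeoX checkpoint
assert not wants_hf_checkpoint("TPUMeshTransformerGPTJ", False)    # MTJ-native GPT-J checkpoint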
@@ -2098,7 +2104,7 @@ def lua_get_modeltype():
return "readonly"
if(vars.model in ("Colab", "OAI", "InferKit")):
return "api"
- if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
+ if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
hidden_size = get_hidden_size_from_model(model)
if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)):
return "gpt2"
@@ -2127,7 +2133,7 @@ def lua_get_modelbackend():
return "readonly"
if(vars.model in ("Colab", "OAI", "InferKit")):
return "api"
- if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
+ if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
return "mtj"
return "transformers"
@@ -2136,7 +2142,7 @@ def lua_get_modelbackend():
#==================================================================#
@bridged_kwarg()
def lua_is_custommodel():
- return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
+ return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
#==================================================================#
#
@@ -3074,22 +3080,22 @@ def calcsubmit(txt):
if(vars.model != "InferKit"):
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
if(actionlen == 0):
- if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
+ if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "OAI"):
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
- elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
+ elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
else:
- if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
+ if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "OAI"):
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
- elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
+ elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
# For InferKit web API
@@ -5105,7 +5111,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
file.close()
# Precompile TPU backend if required
- if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
+ if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
soft_tokens = tpumtjgetsofttokens()
if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
threading.Thread(

tpu_mtj_backend.py

@@ -46,6 +46,7 @@ import numpy as np
import optax
import haiku as hk
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
from tokenizers import Tokenizer
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
from mesh_transformer.util import to_bf16
@@ -800,6 +801,121 @@ def reshard_reverse(x, total_shards, old_shape):
return out
def get_old_shape(t, total_shards, dim=2):
if len(t.shape) == 2:
shard_shape = t.shape
if dim == 1:
assert shard_shape[0] % total_shards == 0
return (shard_shape[0] // total_shards, shard_shape[1])
elif dim == 2:
assert shard_shape[1] % total_shards == 0
return (shard_shape[0], shard_shape[1] // total_shards)
else:
raise ValueError(f"Unsupported dim {dim}")
if len(t.shape) == 1:
assert t.shape[0] % total_shards == 0
return (t.shape[0] // total_shards,)
else:
raise ValueError(f"Unsupported shape {t.shape}")
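# Illustrative example: for a hypothetical (6144, 18432) 2-D weight split across 8
# output shards, get_old_shape returns (6144, 2304) with dim=2 and (768, 18432) with
# dim=1; a 1-D (6144,) bias becomes (768,). These per-shard shapes are what
# reshard_reverse is pointed at in read_neox_checkpoint below.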
def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
assert config["cores_per_replica"] % checkpoint_shards == 0
output_shards = config["cores_per_replica"] // checkpoint_shards
import torch
from tqdm.auto import tqdm
move_xmap = jax.experimental.maps.xmap(
fun=lambda x, _: to_bf16(x),
in_axes=(["shard", ...], ["batch", ...]),
out_axes=["shard", ...],
axis_resources={'shard': 'mp', 'batch': 'dp'}
)
path_template = os.path.join(path, "layer_{layer:02d}-model_{shard:02d}-model_states.pt")
static_mapping = {
"word_embeddings.weight": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1},
"final_linear.weight": {"module": "projection_shard/~/linear", "param": "w", "axis": 2},
"norm.weight": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale", "axis": None},
"norm.bias": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset", "axis": None},
}
layer_mapping = {
"attention.query_key_value.weight": {"module": "combined_qkv", "param": "w", "axis": 2},
"attention.query_key_value.bias": {"module": "combined_qkv", "param": "b", "axis": 1},
"attention.dense.weight": {"module": "linear_3", "param": "w", "axis": 1},
"attention.dense.bias": {"module": "linear_3", "param": "b", "axis": None},
"mlp.dense_h_to_4h.weight": {"module": "linear_4", "param": "w", "axis": 2},
"mlp.dense_h_to_4h.bias": {"module": "linear_4", "param": "b", "axis": 1},
"mlp.dense_4h_to_h.weight": {"module": "linear_5", "param": "w", "axis": 1},
"mlp.dense_4h_to_h.bias": {"module": "linear_5", "param": "b", "axis": None},
"input_layernorm.weight": {"module": "replicated_layer_norm", "param": "scale", "axis": None},
"input_layernorm.bias": {"module": "replicated_layer_norm", "param": "offset", "axis": None},
"post_attention_layernorm.weight": {"module": "replicated_layer_norm_1", "param": "scale", "axis": None},
"post_attention_layernorm.bias": {"module": "replicated_layer_norm_1", "param": "offset", "axis": None},
}
tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping)
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
for checkpoint_layer in range(config["layers"] + 5):
if checkpoint_layer in (1, config["layers"] + 2):
continue
layer = checkpoint_layer - 2
shards = []
for checkpoint_shard in range(checkpoint_shards):
shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
for key in shards[0]:
if key == "attention.rotary_emb.inv_freq":
continue
elif key in static_mapping:
target_module = "causal_transformer_shard/~/" + static_mapping[key]["module"]
target_param = static_mapping[key]["param"]
target_axis = static_mapping[key]["axis"]
elif key in layer_mapping:
target_module = f"causal_transformer_shard/~/layer_{layer}/~/" + layer_mapping[key]["module"]
target_param = layer_mapping[key]["param"]
target_axis = layer_mapping[key]["axis"]
else:
error = f"{repr(key)} not found in mapping"
print("\n\nERROR: ", error, file=sys.stderr)
raise RuntimeError(error)
original_shape = shards[0][key].shape
for checkpoint_shard in range(checkpoint_shards):
if key in ("attention.dense.bias", "mlp.dense_4h_to_h.bias"):
shards[checkpoint_shard][key] /= config["cores_per_replica"]
if key != "word_embeddings.weight" and shards[checkpoint_shard][key].ndim == 2:
shards[checkpoint_shard][key] = shards[checkpoint_shard][key].T
tensor = shards[checkpoint_shard][key]
if target_axis is not None:
target_shape = (output_shards,) + get_old_shape(tensor, total_shards=output_shards, dim=target_axis)
else:
target_shape = (output_shards, tensor.shape[0])
shards[checkpoint_shard][key] = reshard_reverse(tensor.unsqueeze_(0), output_shards, target_shape)
#print(key, ":", original_shape, "->", shards[0][key].shape)
tensor = torch.cat([shards[s][key] for s in range(checkpoint_shards)], dim=0)
target_shape = state["params"][target_module][target_param].shape
if tensor.shape != target_shape:
error = f"Weight {repr(key)} has shape {tensor.shape} in checkpoint but shape {target_shape} was requested by MTJ for {target_module} {target_param}"
print("\n\nERROR: ", error, file=sys.stderr)
raise RuntimeError(error)
if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
tensor = tensor.bfloat16()
state["params"][target_module][target_param] = move_xmap(
jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)).copy(),
np.zeros(config["cores_per_replica"]),
)
bar.update(1)
for mk, mv in state["params"].items():
for pk, pv in mv.items():
if isinstance(pv, PlaceholderTensor):
error = f"{mk} {pk} could not be found in the model checkpoint"
print("\n\nERROR: " + error, file=sys.stderr)
raise RuntimeError(error)
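The outer loop above walks config["layers"] + 5 checkpoint files, skipping indices 1 and config["layers"] + 2, and maps file index n to transformer layer n - 2. A small sketch of that arithmetic, assuming the stock 44-layer NeoX-20B config:

layers = 44
checkpoint_layers = [i for i in range(layers + 5) if i not in (1, layers + 2)]
assert checkpoint_layers[0] == 0 and checkpoint_layers[-2:] == [47, 48]
transformer_layers = [i - 2 for i in checkpoint_layers if 2 <= i <= layers + 1]
assert transformer_layers == list(range(layers))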
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params
@@ -820,6 +936,23 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
}
params = kwargs
if vars.model == "TPUMeshTransformerGPTNeoX":
default_params = {
"compat": "neox",
"layers": 44,
"d_model": 6144,
"n_heads": 64,
"n_vocab": 50432,
"n_vocab_padding": 0,
"norm": "doublelayernorm",
"pe": "neox_rotary",
"pe_rotary_dims": 24,
"seq": 2048,
"cores_per_replica": 8,
"tokenizer_class": "GPT2TokenizerFast",
"tokenizer": "gpt2",
}
# Try to convert HF config.json to MTJ config
if hf_checkpoint:
spec_path = os.path.join("maps", vars.model_type + ".json")
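The NeoX defaults above pin down the 20B architecture: 44 layers, d_model 6144, 64 heads, rotary embeddings on 24 of the 96 dims per head, and a padded vocabulary of 50432. A back-of-the-envelope check that these numbers add up to roughly 20B parameters, assuming the usual 4x MLP expansion and the untied input/output embeddings implied by the separate word_embeddings.weight and final_linear.weight entries in the checkpoint mapping above:

d_model, layers, vocab = 6144, 44, 50432

attention = 4 * d_model * d_model                 # QKV projections (3*d*d) plus output projection (d*d)
mlp = 2 * d_model * (4 * d_model)                 # dense_h_to_4h and dense_4h_to_h
per_layer = attention + mlp                       # ~0.45B per transformer layer
total = layers * per_layer + 2 * vocab * d_model  # plus input and output embeddings
print(f"{total / 1e9:.1f}B")                      # ~20.6B; biases and layer norms add a little more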
@@ -875,7 +1008,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
params[param] = default_params[param]
# Load tokenizer
- if not hf_checkpoint:
+ if vars.model == "TPUMeshTransformerGPTNeoX":
tokenizer = Tokenizer.from_file(os.path.join(path, "20B_tokenizer.json"))
def new_encode(old_encode):
def encode(s, *args, **kwargs):
return old_encode(s).ids
return encode
tokenizer.encode = new_encode(tokenizer.encode)
elif not hf_checkpoint:
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
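The wrapper re-binds tokenizer.encode so it returns a plain list of ids, which is what the rest of the backend expects from the Hugging Face tokenizers used on the other branches; a raw tokenizers.Tokenizer.encode call returns an Encoding object instead. Usage sketch with a hypothetical local path:

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("20B_tokenizer.json")
enc = tokenizer.encode("Hello world")          # tokenizers.Encoding, not a list
ids = enc.ids                                  # the plain list of token ids

def new_encode(old_encode):
    def encode(s, *args, **kwargs):
        return old_encode(s).ids
    return encode

tokenizer.encode = new_encode(tokenizer.encode)
assert tokenizer.encode("Hello world") == ids  # now returns the id list directly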
@@ -917,9 +1057,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
network = PenalizingCausalTransformer(params, dematerialized=True)
- if not hf_checkpoint:
+ if not hf_checkpoint and vars.model != "TPUMeshTransformerGPTNeoX":
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
- network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+ #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
return
if vars.model == "TPUMeshTransformerGPTNeoX":
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
read_neox_checkpoint(network.state, path, params)
return
# Convert from HF checkpoint
@@ -945,7 +1090,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
import torch_lazy_loader
import torch
- from tqdm import tqdm
+ from tqdm.auto import tqdm
def callback(model_dict, f, **_):
with zipfile.ZipFile(f, "r") as z:
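The parameter-count message in the NeoX branch above uses a small formatting trick: format the integer with thousands separators, then swap the commas for spaces. Standalone illustration (the count is hypothetical, roughly what the 20B defaults imply):

n = 20_551_041_024
print("This model has ", f"{n:,d}".replace(",", " "), " parameters.")
# -> This model has  20 551 041 024  parameters.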
@@ -1031,6 +1176,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if os.path.isdir(vars.model.replace('/', '_')):
import shutil
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
print("\n", flush=True)
with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
if(os.path.isdir(vars.custmodpth)):
try:
@@ -1069,4 +1215,4 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
except Exception as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")
- network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+ #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))