Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
	Merge pull request #101 from VE-FORBRYDERNE/neox
GPT-NeoX-20B support in Colab TPU instances

aiserver.py
@@ -196,6 +196,7 @@ class vars:
    corescript  = "default.lua"  # Filename of corescript to load
    # badwords    = []     # Array of str/chr values that should be removed from output
|     badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting | ||||
|     badwordsids_neox = [[0], [1], [44162], [9502], [12520], [31841], [36320], [49824], [34417], [6038], [34494], [24815], [26635], [24345], [3455], [28905], [44270], [17278], [32666], [46880], [7086], [43189], [37322], [17778], [20879], [49821], [3138], [14490], [4681], [21391], [26786], [43134], [9336], [683], [48074], [41256], [19181], [29650], [28532], [36487], [45114], [46275], [16445], [15104], [11337], [1168], [5647], [29], [27482], [44965], [43782], [31011], [42944], [47389], [6334], [17548], [38329], [32044], [35487], [2239], [34761], [7444], [1084], [12399], [18990], [17636], [39083], [1184], [35830], [28365], [16731], [43467], [47744], [1138], [16079], [40116], [45564], [18297], [42368], [5456], [18022], [42696], [34476], [23505], [23741], [39334], [37944], [45382], [38709], [33440], [26077], [43600], [34418], [36033], [6660], [48167], [48471], [15775], [19884], [41533], [1008], [31053], [36692], [46576], [20095], [20629], [31759], [46410], [41000], [13488], [30952], [39258], [16160], [27655], [22367], [42767], [43736], [49694], [13811], [12004], [46768], [6257], [37471], [5264], [44153], [33805], [20977], [21083], [25416], [14277], [31096], [42041], [18331], [33376], [22372], [46294], [28379], [38475], [1656], [5204], [27075], [50001], [16616], [11396], [7748], [48744], [35402], [28120], [41512], [4207], [43144], [14767], [15640], [16595], [41305], [44479], [38958], [18474], [22734], [30522], [46267], [60], [13976], [31830], [48701], [39822], [9014], [21966], [31422], [28052], [34607], [2479], [3851], [32214], [44082], [45507], [3001], [34368], [34758], [13380], [38363], [4299], [46802], [30996], [12630], [49236], [7082], [8795], [5218], [44740], [9686], [9983], [45301], [27114], [40125], [1570], [26997], [544], [5290], [49193], [23781], [14193], [40000], [2947], [43781], [9102], [48064], [42274], [18772], [49384], [9884], [45635], [43521], [31258], [32056], [47686], [21760], [13143], [10148], [26119], [44308], [31379], [36399], [23983], [46694], [36134], [8562], [12977], [35117], [28591], [49021], [47093], [28653], [29013], [46468], [8605], [7254], [25896], [5032], [8168], [36893], [38270], [20499], [27501], [34419], [29547], [28571], [36586], [20871], [30537], [26842], [21375], [31148], [27618], [33094], [3291], [31789], [28391], [870], [9793], [41361], [47916], [27468], [43856], [8850], [35237], [15707], [47552], [2730], [41449], [45488], [3073], [49806], [21938], [24430], [22747], [20924], [46145], [20481], [20197], [8239], [28231], [17987], [42804], [47269], [29972], [49884], [21382], [46295], [36676], [34616], [3921], [26991], [27720], [46265], [654], [9855], [40354], [5291], [34904], [44342], [2470], [14598], [880], [19282], [2498], [24237], [21431], [16369], [8994], [44524], [45662], [13663], [37077], [1447], [37786], [30863], [42854], [1019], [20322], [4398], [12159], [44072], [48664], [31547], [18736], [9259], [31], [16354], [21810], [4357], [37982], [5064], [2033], [32871], [47446], [62], [22158], [37387], [8743], [47007], [17981], [11049], [4622], [37916], [36786], [35138], [29925], [14157], [18095], [27829], [1181], [22226], [5709], [4725], [30189], [37014], [1254], [11380], [42989], [696], [24576], [39487], [30119], [1092], [8088], [2194], [9899], [14412], [21828], [3725], [13544], [5180], [44679], [34398], [3891], [28739], [14219], [37594], [49550], [11326], [6904], [17266], [5749], [10174], [23405], [9955], [38271], [41018], [13011], [48392], [36784], [24254], [21687], [23734], [5413], [41447], [45472], [10122], [17555], [15830], [47384], [12084], [31350], 
[47940], [11661], [27988], [45443], [905], [49651], [16614], [34993], [6781], [30803], [35869], [8001], [41604], [28118], [46462], [46762], [16262], [17281], [5774], [10943], [5013], [18257], [6750], [4713], [3951], [11899], [38791], [16943], [37596], [9318], [18413], [40473], [13208], [16375]] | ||||
    deletewi    = None   # Temporary storage for UID to delete
    wirmvwhtsp  = False  # Whether to remove leading whitespace from WI entries
    widepth     = 3      # How many historical actions to scan for WI hits
@@ -317,7 +318,7 @@ def getmodelname():
    if(args.configname):
       modelname = args.configname
       return modelname
    if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ")):
    if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
        modelname = os.path.basename(os.path.normpath(vars.custmodpth))
        return modelname
    else:
@@ -699,7 +700,7 @@ def spRequest(filename):
    vars.sp_length = tensor.shape[-2]
    vars.spmeta["n_tokens"] = vars.sp_length

    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
        rows = tensor.shape[0]
        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
        tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
@@ -790,7 +791,7 @@ else:
    getModelSelection(mainmenu)

# If transformers model was selected & GPU available, ask to use CPU or GPU
if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
    vars.allowsp = True
    # Test for GPU support
    import torch
@@ -830,7 +831,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe
        print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
        vars.model_type = "gpt_neo"

if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
    loadmodelsettings()
    loadsettings()
    print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -1032,7 +1033,7 @@ socketio = SocketIO(app, async_method="eventlet")
print("{0}OK!{1}".format(colors.GREEN, colors.END))

# Start transformers and create pipeline
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
    if(not vars.noai):
        print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
        from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
@@ -1050,7 +1051,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
            if not vars.lazy_load:
                return

            from tqdm import tqdm
            from tqdm.auto import tqdm
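            # tqdm.auto picks the notebook-friendly progress bar when one is available (e.g. under Colab) and falls back to the plain console bar otherwise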

            if "breakmodel" in globals():
                gpu_blocks = breakmodel.gpu_blocks
@@ -1380,6 +1381,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
            if os.path.isdir(vars.model.replace('/', '_')):
                import shutil
                shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
            print("\n", flush=True)
            with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer) if vars.lazy_load else None, dematerialized_modules=True):
                if(vars.lazy_load):  # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
                    lowmem = {}
@@ -1553,11 +1555,15 @@ else:
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
        loadsettings()
    # Load the TPU backend if requested
    elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
    elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
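        # GPT-NeoX-20B uses a different vocabulary, so substitute the NeoX-specific bad-words token list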
        if(vars.model == "TPUMeshTransformerGPTNeoX"):
            vars.badwordsids = vars.badwordsids_neox
        print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
        if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
        if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
            raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
        import tpu_mtj_backend
        if(vars.model == "TPUMeshTransformerGPTNeoX"):
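            # In the GPT-NeoX-20B vocabulary, token id 1 is the padding token (id 0 is end-of-text)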
            tpu_mtj_backend.pad_token_id = 1
        tpu_mtj_backend.vars = vars
        tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
        tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
@@ -1567,7 +1573,7 @@ else:
        vars.allowsp = True
        loadmodelsettings()
        loadsettings()
        tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model != "TPUMeshTransformerGPTJ" and vars.use_colab_tpu, **vars.modelconfig)
        tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
        vars.modeldim = int(tpu_mtj_backend.params["d_model"])
        tokenizer = tpu_mtj_backend.tokenizer
    else:
@@ -2098,7 +2104,7 @@ def lua_get_modeltype():
        return "readonly"
    if(vars.model in ("Colab", "OAI", "InferKit")):
        return "api"
    if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
    if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
        hidden_size = get_hidden_size_from_model(model)
    if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)):
        return "gpt2"
@@ -2127,7 +2133,7 @@ def lua_get_modelbackend():
        return "readonly"
    if(vars.model in ("Colab", "OAI", "InferKit")):
        return "api"
    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
        return "mtj"
    return "transformers"

@@ -2136,7 +2142,7 @@ def lua_get_modelbackend():
#==================================================================#
@bridged_kwarg()
def lua_is_custommodel():
    return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
    return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")

#==================================================================#
#
@@ -3074,22 +3080,22 @@ def calcsubmit(txt):
    if(vars.model != "InferKit"):
        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
        if(actionlen == 0):
            if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
            if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
                generate(subtxt, min, max, found_entries=found_entries)
            elif(vars.model == "Colab"):
                sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
            elif(vars.model == "OAI"):
                oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
            elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
            elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
                tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
        else:
            if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
            if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
                generate(subtxt, min, max, found_entries=found_entries)
            elif(vars.model == "Colab"):
                sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
            elif(vars.model == "OAI"):
                oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
            elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
            elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
                tpumtjgenerate(subtxt, min, max, found_entries=found_entries)

    # For InferKit web API
@@ -5105,7 +5111,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
    file.close()

# Precompile TPU backend if required
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
    soft_tokens = tpumtjgetsofttokens()
    if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
        threading.Thread(

tpu_mtj_backend.py
@@ -46,6 +46,7 @@ import numpy as np
import optax
import haiku as hk
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
from tokenizers import Tokenizer
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
from mesh_transformer.util import to_bf16
@@ -800,6 +801,121 @@ def reshard_reverse(x, total_shards, old_shape):
    return out


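# Return the shape that a tensor with shape t.shape will have on each shard once it is
# split across `total_shards` along the given 1-based dimension (1-D tensors are split
# along their only axis).  For example, a (6144, 50432) weight split 8 ways along dim 2
# comes out as (6144, 6304) per shard.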
def get_old_shape(t, total_shards, dim=2):
    if len(t.shape) == 2:
        shard_shape = t.shape
        if dim == 1:
            assert shard_shape[0] % total_shards == 0
            return (shard_shape[0] // total_shards, shard_shape[1])
        elif dim == 2:
            assert shard_shape[1] % total_shards == 0
            return (shard_shape[0], shard_shape[1] // total_shards)
        else:
            raise ValueError(f"Unsupported dim {dim}")
    if len(t.shape) == 1:
        assert t.shape[0] % total_shards == 0
        return (t.shape[0] // total_shards,)
    else:
        raise ValueError(f"Unsupported shape {t.shape}")


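# Read a GPT-NeoX checkpoint saved in the DeepSpeed pipeline layout (one
# layer_XX-model_YY-model_states.pt file per pipeline stage and model-parallel shard;
# the released 20B weights ship two shards) and load its weights into the MTJ network
# state, resharding them across the TPU cores.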
def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
    assert config["cores_per_replica"] % checkpoint_shards == 0
    output_shards = config["cores_per_replica"] // checkpoint_shards

    import torch
    from tqdm.auto import tqdm

    move_xmap = jax.experimental.maps.xmap(
        fun=lambda x, _: to_bf16(x),
        in_axes=(["shard", ...], ["batch", ...]),
        out_axes=["shard", ...],
        axis_resources={'shard': 'mp', 'batch': 'dp'}
    )

    path_template = os.path.join(path, "layer_{layer:02d}-model_{shard:02d}-model_states.pt")

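    # Map NeoX checkpoint tensor names onto MTJ/haiku module paths.  "axis" is the
    # 1-based dimension that gets split across model-parallel shards; None means the
    # tensor is replicated on every shard.  layer_mapping paths are relative to each
    # transformer layer's module.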
    static_mapping = {
        "word_embeddings.weight": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1},
        "final_linear.weight": {"module": "projection_shard/~/linear", "param": "w", "axis": 2},
        "norm.weight": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale", "axis": None},
        "norm.bias": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset", "axis": None},
    }

    layer_mapping = {
        "attention.query_key_value.weight": {"module": "combined_qkv", "param": "w", "axis": 2},
        "attention.query_key_value.bias": {"module": "combined_qkv", "param": "b", "axis": 1},
        "attention.dense.weight": {"module": "linear_3", "param": "w", "axis": 1},
        "attention.dense.bias": {"module": "linear_3", "param": "b", "axis": None},
        "mlp.dense_h_to_4h.weight": {"module": "linear_4", "param": "w", "axis": 2},
        "mlp.dense_h_to_4h.bias": {"module": "linear_4", "param": "b", "axis": 1},
        "mlp.dense_4h_to_h.weight": {"module": "linear_5", "param": "w", "axis": 1},
        "mlp.dense_4h_to_h.bias": {"module": "linear_5", "param": "b", "axis": None},
        "input_layernorm.weight": {"module": "replicated_layer_norm", "param": "scale", "axis": None},
        "input_layernorm.bias": {"module": "replicated_layer_norm", "param": "offset", "axis": None},
        "post_attention_layernorm.weight": {"module": "replicated_layer_norm_1", "param": "scale", "axis": None},
        "post_attention_layernorm.bias": {"module": "replicated_layer_norm_1", "param": "offset", "axis": None},
    }

    tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping)
    bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")

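    # Transformer block i is stored as checkpoint layer i+2; checkpoint layers 1 and
    # layers+2 contain nothing that maps onto MTJ parameters and are skipped.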
    for checkpoint_layer in range(config["layers"] + 5):
        if checkpoint_layer in (1, config["layers"] + 2):
            continue
        layer = checkpoint_layer - 2
        shards = []
        for checkpoint_shard in range(checkpoint_shards):
            shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
        for key in shards[0]:
            if key == "attention.rotary_emb.inv_freq":
                continue
            elif key in static_mapping:
                target_module = "causal_transformer_shard/~/" + static_mapping[key]["module"]
                target_param = static_mapping[key]["param"]
                target_axis = static_mapping[key]["axis"]
            elif key in layer_mapping:
                target_module = f"causal_transformer_shard/~/layer_{layer}/~/" + layer_mapping[key]["module"]
                target_param = layer_mapping[key]["param"]
                target_axis = layer_mapping[key]["axis"]
            else:
                error = f"{repr(key)} not found in mapping"
                print("\n\nERROR: ", error, file=sys.stderr)
                raise RuntimeError(error)
            original_shape = shards[0][key].shape
            for checkpoint_shard in range(checkpoint_shards):
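                # MTJ adds these two biases once per model-parallel shard before the shard
                # outputs are summed, so pre-divide them to keep the summed result unchanged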
                if key in ("attention.dense.bias", "mlp.dense_4h_to_h.bias"):
                    shards[checkpoint_shard][key] /= config["cores_per_replica"]
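                # NeoX stores Linear weights as (out_features, in_features) while MTJ's haiku
                # modules expect the transpose; the embedding matrix is already oriented correctly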
                if key != "word_embeddings.weight" and shards[checkpoint_shard][key].ndim == 2:
                    shards[checkpoint_shard][key] = shards[checkpoint_shard][key].T
                tensor = shards[checkpoint_shard][key]
                if target_axis is not None:
                    target_shape = (output_shards,) + get_old_shape(tensor, total_shards=output_shards, dim=target_axis)
                else:
                    target_shape = (output_shards, tensor.shape[0])
                shards[checkpoint_shard][key] = reshard_reverse(tensor.unsqueeze_(0), output_shards, target_shape)
            #print(key, ":", original_shape, "->", shards[0][key].shape)
            tensor = torch.cat([shards[s][key] for s in range(checkpoint_shards)], dim=0)
            target_shape = state["params"][target_module][target_param].shape
            if tensor.shape != target_shape:
                error = f"Weight {repr(key)} has shape {tensor.shape} in checkpoint but shape {target_shape} was requested by MTJ for {target_module} {target_param}"
                print("\n\nERROR: ", error, file=sys.stderr)
                raise RuntimeError(error)
            if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
                tensor = tensor.bfloat16()
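            # Hand the tensor to JAX through DLPack (the .copy() detaches it from the
            # Torch-owned buffer), then move_xmap distributes one slice per TPU core,
            # casting to bfloat16 as needed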
            state["params"][target_module][target_param] = move_xmap(
                jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)).copy(),
                np.zeros(config["cores_per_replica"]),
            )
            bar.update(1)
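    # Any parameter still left as a PlaceholderTensor was never written by the loop
    # above, which means it is missing from the checkpoint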
    for mk, mv in state["params"].items():
        for pk, pv in mv.items():
            if isinstance(pv, PlaceholderTensor):
                error = f"{mk} {pk} could not be found in the model checkpoint"
                print("\n\nERROR:  " + error, file=sys.stderr)
                raise RuntimeError(error)


def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
    global thread_resources_env, seq, tokenizer, network, params

@@ -820,6 +936,23 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
    }
    params = kwargs

    if vars.model == "TPUMeshTransformerGPTNeoX":
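        # Fixed hyperparameters of the released GPT-NeoX-20B checkpoint: 44 layers,
        # d_model 6144, 64 attention heads, a 50432-token vocabulary, rotary position
        # embeddings on the first 24 dimensions of each head, and a 2048-token context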
        default_params = {
            "compat": "neox",
            "layers": 44,
            "d_model": 6144,
            "n_heads": 64,
            "n_vocab": 50432,
            "n_vocab_padding": 0,
            "norm": "doublelayernorm",
            "pe": "neox_rotary",
            "pe_rotary_dims": 24,
            "seq": 2048,
            "cores_per_replica": 8,
            "tokenizer_class": "GPT2TokenizerFast",
            "tokenizer": "gpt2",
        }

    # Try to convert HF config.json to MTJ config
    if hf_checkpoint:
        spec_path = os.path.join("maps", vars.model_type + ".json")
@@ -875,7 +1008,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
            params[param] = default_params[param]

    # Load tokenizer
    if not hf_checkpoint:
    if vars.model == "TPUMeshTransformerGPTNeoX":
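        # GPT-NeoX-20B ships its vocabulary as a plain `tokenizers` JSON file rather than
        # a Hugging Face tokenizer, and Tokenizer.encode() returns an Encoding object, so
        # wrap it to return raw token ids like the tokenizers used elsewhere in this file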
        tokenizer = Tokenizer.from_file(os.path.join(path, "20B_tokenizer.json"))
        def new_encode(old_encode):
            def encode(s, *args, **kwargs):
                return old_encode(s).ids
            return encode
        tokenizer.encode = new_encode(tokenizer.encode)
    elif not hf_checkpoint:
        if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
            raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
        tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
@@ -917,9 +1057,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo

    network = PenalizingCausalTransformer(params, dematerialized=True)

    if not hf_checkpoint:
    if not hf_checkpoint and vars.model != "TPUMeshTransformerGPTNeoX":
        network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
        network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
        #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
        return

    if vars.model == "TPUMeshTransformerGPTNeoX":
        print("\n\n\nThis model has  ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), "  parameters.\n")
        read_neox_checkpoint(network.state, path, params)
        return

    # Convert from HF checkpoint
@@ -945,7 +1090,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo

    import torch_lazy_loader
    import torch
    from tqdm import tqdm
    from tqdm.auto import tqdm

    def callback(model_dict, f, **_):
        with zipfile.ZipFile(f, "r") as z:
@@ -1031,6 +1176,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
    if os.path.isdir(vars.model.replace('/', '_')):
        import shutil
        shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
    print("\n", flush=True)
    with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
        if(os.path.isdir(vars.custmodpth)):
            try:
@@ -1069,4 +1215,4 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
            except Exception as e:
                model     = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")

    network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
    #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
