diff --git a/aiserver.py b/aiserver.py index 2652ea4b..8444080a 100644 --- a/aiserver.py +++ b/aiserver.py @@ -26,6 +26,8 @@ import traceback import threading import markdown import bleach +import itertools +import bisect from collections.abc import Iterable from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List @@ -250,6 +252,8 @@ class vars: newlinemode = "n" quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page) debug = False # If set to true, will send debug information to the client for display + lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage + use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" # Whether or not we're in a Colab TPU instance and are going to use the TPU rather than the CPU utils.vars = vars @@ -337,10 +341,10 @@ def device_list(n_layers, primary=None, selected=None): sep_color = colors.YELLOW print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}") -def device_config(model): +def device_config(config): global breakmodel, generator import breakmodel - n_layers = model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer + n_layers = config.num_layers if hasattr(config, "num_layers") else config.n_layer if(args.breakmodel_gpulayers is not None): try: breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(','))) @@ -413,22 +417,30 @@ def device_config(model): # If all layers are on the same device, use the old GPU generation mode while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0): breakmodel.gpu_blocks.pop() - if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer)): + if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, config.num_layers if hasattr(config, "num_layers") else config.n_layer)): vars.breakmodel = False vars.usegpu = True vars.gpu_device = len(breakmodel.gpu_blocks)-1 - model = model.half().to(vars.gpu_device) - generator = model.generate return if(not breakmodel.gpu_blocks): print("Nothing assigned to a GPU, reverting to CPU only mode") vars.breakmodel = False vars.usegpu = False - model = model.to('cpu').float() + return + +def move_model_to_devices(model): + global generator + + if(not vars.breakmodel): + if(vars.usegpu): + model = model.half().to(vars.gpu_device) + else: + model = model.to('cpu').float() generator = model.generate return - model.half().to('cpu') + + model.half() gc.collect() if(hasattr(model, "transformer")): model.transformer.wte.to(breakmodel.primary_device) @@ -684,7 +696,7 @@ def spRequest(filename): vars.sp_length = tensor.shape[-2] vars.spmeta["n_tokens"] = vars.sp_length - if(vars.model in ("TPUMeshTransformerGPTJ",)): + if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): rows = tensor.shape[0] padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows tensor = np.pad(tensor, ((0, padding_amount), (0, 0))) @@ -756,6 +768,9 @@ if args.ngrok: if args.host: vars.host = True; +if args.cpu: + vars.use_colab_tpu = False + vars.smandelete = vars.host == args.override_delete vars.smanrename = vars.host == args.override_rename @@ -772,7 +787,7 @@ else: getModelSelection(mainmenu) # If transformers model was selected & GPU available, ask to use 
CPU or GPU -if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): +if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): vars.allowsp = True # Test for GPU support import torch @@ -811,6 +826,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme elif(vars.model_type == "not_found"): print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)") vars.model_type = "gpt_neo" + +if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): loadmodelsettings() loadsettings() print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="") @@ -1003,7 +1020,7 @@ socketio = SocketIO(app, async_method="eventlet") print("{0}OK!{1}".format(colors.GREEN, colors.END)) # Start transformers and create pipeline -if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): +if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): if(not vars.noai): print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END)) from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer @@ -1015,6 +1032,71 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme import transformers.generation_utils from transformers import __version__ as transformers_version + # Lazy loader + import torch_lazy_loader + def get_lazy_load_callback(n_layers, convert_to_float16=True): + if not vars.lazy_load: + return + + from tqdm import tqdm + + if "breakmodel" in globals(): + gpu_blocks = breakmodel.gpu_blocks + ram_blocks = ram_blocks = n_layers - sum(gpu_blocks) + cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks)) + else: + ram_blocks = gpu_blocks = cumulative_gpu_blocks = None + + def lazy_load_callback(model_dict, f, **_): + device_map = {} + + for _key, spec in lazy_load_spec.get("layer_weights", {}).items(): + for layer in range(n_layers): + key = _key.format(layer=layer) + if key not in model_dict: + continue + device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel or layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks) + device_map[key] = device + + for key, value in model_dict.items(): + if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map: + device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" + + with zipfile.ZipFile(f, "r") as z: + try: + last_storage_key = None + f = None + for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"): + storage_key = model_dict[key].key + if storage_key != last_storage_key: + last_storage_key = storage_key + if isinstance(f, zipfile.ZipExtFile): + f.close() + f = z.open(f"archive/data/{storage_key}") + current_offset = f.tell() + if current_offset != model_dict[key].seek_offset: + f.seek(model_dict[key].seek_offset - current_offset, 1) + device = device_map[key] + #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... 
", end="", flush=True) + model_dict[key] = model_dict[key].materialize(f, map_location="cpu") + if convert_to_float16 and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32: + model_dict[key] = model_dict[key].to(torch.float16) + model_dict[key] = model_dict[key].to(device) + #print("OK", flush=True) + finally: + if isinstance(f, zipfile.ZipExtFile): + f.close() + + return lazy_load_callback + + lazy_load_config_path = os.path.join(path.dirname(path.realpath(__file__)), "maps", vars.model_type + ".json") + if(vars.lazy_load and "model_config" in globals() and os.path.isfile(lazy_load_config_path)): + with open(lazy_load_config_path) as f: + lazy_load_spec = json.load(f) + + else: + vars.lazy_load = False + # Some versions of transformers 4.17.0.dev0 are affected by # https://github.com/huggingface/transformers/issues/15736 # This is a workaround for those versions of transformers. @@ -1250,6 +1332,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # If custom GPT2 model was chosen if(vars.model == "GPT2Custom"): + vars.lazy_load = False model_config = open(vars.custmodpth + "/config.json", "r") js = json.load(model_config) with(maybe_use_float16()): @@ -1271,6 +1354,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # feature yet if(vars.model_type == "gpt2"): lowmem = {} + + # If we're using torch_lazy_loader, we need to get breakmodel config + # early so that it knows where to load the individual model tensors + if(vars.lazy_load and vars.hascuda and vars.breakmodel): + device_config(model_config) # Download model from Huggingface if it does not exist, otherwise load locally @@ -1278,43 +1366,43 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme if os.path.isdir(vars.model.replace('/', '_')): import shutil shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_'))) - if(os.path.isdir(vars.custmodpth)): - with(maybe_use_float16()): - try: - tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache") - try: - model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) - except ValueError as e: - model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) - elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))): - with(maybe_use_float16()): - try: - tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") - try: - model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) - except ValueError as e: - model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) - else: - try: - tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache") - with(maybe_use_float16()): + with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer), 
dematerialized_modules=True): + if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time + lowmem = {} + if(os.path.isdir(vars.custmodpth)): + try: + tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) + elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))): + try: + tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) + else: + try: + tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache") try: model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem) except ValueError as e: model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem) - - if not args.colab: - import shutil - model = model.half() - model.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) - tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) - shutil.rmtree("cache/") + + if not args.colab: + import shutil + model = model.half() + model.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) + tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) + shutil.rmtree("cache/") if(vars.hascuda): if(vars.usegpu): @@ -1323,7 +1411,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme generator = model.generate elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel) vars.modeldim = get_hidden_size_from_model(model) - device_config(model) + if(not vars.lazy_load): + device_config(model.config) + move_model_to_devices(model) else: model = model.to('cpu').float() vars.modeldim = get_hidden_size_from_model(model) @@ -1440,9 +1530,9 @@ else: tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/") loadsettings() # Load the TPU backend if requested - elif(vars.model == "TPUMeshTransformerGPTJ"): + elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) - if not vars.custmodpth or not os.path.isdir(vars.custmodpth): + if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)): raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder") import tpu_mtj_backend tpu_mtj_backend.vars = vars @@ -1454,7 +1544,7 @@ else: vars.allowsp = True loadmodelsettings() loadsettings() - tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig) + tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model != "TPUMeshTransformerGPTJ" and vars.use_colab_tpu, **vars.modelconfig) 
vars.modeldim = int(tpu_mtj_backend.params["d_model"]) tokenizer = tpu_mtj_backend.tokenizer else: @@ -1985,7 +2075,7 @@ def lua_get_modeltype(): return "readonly" if(vars.model in ("Colab", "OAI", "InferKit")): return "api" - if(vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))): + if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))): hidden_size = get_hidden_size_from_model(model) if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)): return "gpt2" @@ -2001,7 +2091,7 @@ def lua_get_modeltype(): return "gpt-neo-1.3B" if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model_type == "gpt_neo" and hidden_size == 2560)): return "gpt-neo-2.7B" - if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)): + if(vars.model in ("EleutherAI/gpt-j-6B",) or ((vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ") and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)): return "gpt-j-6B" return "unknown" @@ -2014,7 +2104,7 @@ def lua_get_modelbackend(): return "readonly" if(vars.model in ("Colab", "OAI", "InferKit")): return "api" - if(vars.model in ("TPUMeshTransformerGPTJ",)): + if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): return "mtj" return "transformers" @@ -2961,22 +3051,22 @@ def calcsubmit(txt): if(vars.model != "InferKit"): subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt) if(actionlen == 0): - if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): + if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "OAI"): oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) - elif(vars.model == "TPUMeshTransformerGPTJ"): + elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): tpumtjgenerate(subtxt, min, max, found_entries=found_entries) else: - if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): + if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "OAI"): oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) - elif(vars.model == "TPUMeshTransformerGPTJ"): + elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): tpumtjgenerate(subtxt, min, max, found_entries=found_entries) # For InferKit web API @@ -4987,7 +5077,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")): file.close() # Precompile TPU backend if required -if(vars.model in ("TPUMeshTransformerGPTJ",)): +if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): soft_tokens = tpumtjgetsofttokens() if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)): threading.Thread( diff --git a/maps/gpt_neo.json b/maps/gpt_neo.json new file mode 100644 index 00000000..fa2d4084 --- 
/dev/null +++ b/maps/gpt_neo.json @@ -0,0 +1,32 @@ +{ + "mtj_compat": "neo", + "mtj_pe": "fixed", + "mtj_config_map": { + "d_model": "hidden_size", + "n_heads": "num_heads", + "layers": "num_layers" + }, + "static_weights": { + "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, + "transformer.wpe.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose"]}}, + "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}}, + "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}} + }, + "layer_weights": { + "transformer.h.{layer}.attn.attention.bias": {}, + "transformer.h.{layer}.attn.attention.masked_bias": {}, + "transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}}, + "transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}}, + "transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}}, + "transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}}, + "transformer.h.{layer}.attn.attention.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}}, + "transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}}, + "transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}}, + "transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}}, + "transformer.h.{layer}.mlp.c_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, + "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}}, + "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}}, + "transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}}, + "transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}} + } +} diff --git a/maps/gptj.json b/maps/gptj.json new file mode 100644 index 00000000..8e0bc9da --- /dev/null +++ b/maps/gptj.json @@ -0,0 +1,32 @@ +{ + "mtj_compat": "j", + "mtj_pe": "rotary", + "mtj_config_map": { + "pe_rotary_dims": ["rotary_dim", 64], + "d_model": "n_embd", + "n_heads": "n_head", + "layers": "n_layer" + }, + "static_weights": { + "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, + "transformer.wte.bias": {"mtj": {"module": "embedding_shard/~/linear", "param": "b"}}, + "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}}, + "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}, + "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}}, + "lm_head.bias": {"mtj": {"module": "projection_shard/~/linear", "param": "b"}} + }, + "layer_weights": { + "transformer.h.{layer}.attn.bias": {}, + "transformer.h.{layer}.attn.masked_bias": {}, + 
"transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}}, + "transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}}, + "transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}}, + "transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}}, + "transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}}, + "transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}}, + "transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}}, + "transformer.h.{layer}.mlp.fc_out.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, + "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}}, + "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}} + } +} diff --git a/maps/xglm.json b/maps/xglm.json new file mode 100644 index 00000000..65ab5e5e --- /dev/null +++ b/maps/xglm.json @@ -0,0 +1,33 @@ +{ + "mtj_compat": "fairseq_lm", + "mtj_pe": "fairseq_sinusoidal", + "mtj_config_map": { + "d_model": "d_model", + "n_heads": "attention_heads", + "layers": "num_layers" + }, + "static_weights": { + "model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, + "model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}}, + "model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}, + "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}} + }, + "layer_weights": { + "model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}}, + "model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b"}}, + "model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}}, + "model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b"}}, + "model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}}, + "model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b"}}, + "model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}}, + "model.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}}, + "model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}}, + "model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}}, + "model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}}, + "model.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, + "model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}}, + "model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": 
"layer_{layer}/~/replicated_layer_norm", "param": "offset"}}, + "model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}}, + "model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}} + } +} diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py new file mode 100644 index 00000000..d097675f --- /dev/null +++ b/torch_lazy_loader.py @@ -0,0 +1,250 @@ +''' +This file is AGPL-licensed. + +Some of the code in this file is copied from PyTorch. + +The license for PyTorch is shown below: + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. 
+''' + + +import contextlib +from functools import reduce +import itertools +import zipfile +import pickle +import torch +from torch.nn import Module +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union + + +_EXTRA_STATE_KEY_SUFFIX = '_extra_state' + + +class LazyTensor: + def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): + self.storage_type = storage_type + self.key = key + self.location = location + self.seek_offset = seek_offset + self.shape = shape + self.stride = stride + self.requires_grad = requires_grad + self.backward_hooks = backward_hooks + + def __view(self, f: Callable): + return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})" + + def __repr__(self): + return self.__view(repr) + + def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor: + size = reduce(lambda x, y: x * y, self.shape, 1) + dtype = self.storage_type(0).dtype + nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) + if isinstance(checkpoint, zipfile.ZipFile): + f = checkpoint.open(f"archive/data/{self.key}", "r") + f.seek(self.seek_offset) + else: + f = checkpoint + try: + storage = self.storage_type.from_buffer(f.read(nbytes), "little") + finally: + if isinstance(checkpoint, zipfile.ZipFile): + f.close() + storage = torch.serialization._get_restore_location(map_location)(storage, self.location) + tensor = torch.tensor([], dtype=storage.dtype, device=storage.device) + tensor.set_(storage, 0, self.shape, self.stride) + tensor.requires_grad = self.requires_grad + tensor._backward_hooks = self.backward_hooks + return tensor + + +class _LazyUnpickler(pickle.Unpickler): + lazy_loaded_storages: Dict[str, LazyTensor] + + def __init__(self, *args, **kwargs): + self.lazy_loaded_storages = {} + return super().__init__(*args, **kwargs) + + def forced_persistent_load(self, saved_id): + assert isinstance(saved_id, tuple) + typename = saved_id[0] + assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + storage_type, key, location, _ = saved_id[1:] + return LazyTensor(storage_type, key, location) + + def load(self, *args, **kwargs): + self.persistent_load = self.forced_persistent_load + retval = super().load(*args, **kwargs) + self.lazy_loaded_storages = {} + return retval + + +def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride): + lazy_storage.shape = shape + lazy_storage.stride = stride + dtype = lazy_storage.storage_type(0).dtype + lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) + return lazy_storage + + +# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438 +def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, 
unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] + if not torch.overrides.is_tensor_like(input_param): + error_msgs.append('While copying the parameter named "{}", ' + 'expected torch.Tensor or Tensor-like object from checkpoint but ' + 'received {}' + .format(key, type(input_param))) + continue + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.' + .format(key, input_param.shape, param.shape)) + continue + try: + with torch.no_grad(): + #param.copy_(input_param) + new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad) # This line is new + if name in self._parameters: # This line is new + self._parameters[name] = new_param # This line is new + if name in persistent_buffers: # This line is new + self._buffers[name] = new_param # This line is new + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.' 
+ .format(key, param.size(), input_param.size(), ex.args)) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) + + +@contextlib.contextmanager +def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False): + if not enable: + yield False + return + + try: + old_unpickler = pickle.Unpickler + pickle.Unpickler = _LazyUnpickler + + old_rebuild_tensor = torch._utils._rebuild_tensor + torch._utils._rebuild_tensor = _rebuild_tensor + + old_torch_load = torch.load + + def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args): + retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) + if callback is not None: + callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) + return retval + + torch.load = torch_load + + if dematerialized_modules: + old_linear_init = torch.nn.Linear.__init__ + old_embedding_init = torch.nn.Embedding.__init__ + old_layernorm_init = torch.nn.LayerNorm.__init__ + + def linear_init(self, *args, device=None, **kwargs): + return old_linear_init(self, *args, device="meta", **kwargs) + + def embedding_init(self, *args, device=None, **kwargs): + return old_embedding_init(self, *args, device="meta", **kwargs) + + def layernorm_init(self, *args, device=None, **kwargs): + return old_layernorm_init(self, *args, device="meta", **kwargs) + + torch.nn.Linear.__init__ = linear_init + torch.nn.Embedding.__init__ = embedding_init + torch.nn.LayerNorm.__init__ = layernorm_init + old_load_from_state_dict = torch.nn.Module._load_from_state_dict + torch.nn.Module._load_from_state_dict = _load_from_state_dict + + yield True + + finally: + pickle.Unpickler = old_unpickler + torch._utils._rebuild_tensor = old_rebuild_tensor + torch.load = old_torch_load + if dematerialized_modules: + torch.nn.Linear.__init__ = old_linear_init + torch.nn.Embedding.__init__ = old_embedding_init + torch.nn.LayerNorm.__init__ = old_layernorm_init + torch.nn.Module._load_from_state_dict = old_load_from_state_dict diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index a78f93f2..6f5500b7 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -32,6 +32,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar import progressbar import time import os +import sys +import json +import zipfile import requests import random import jax @@ -41,9 +44,10 @@ import jax.numpy as jnp import numpy as np import optax import haiku as hk -import transformers +from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM from mesh_transformer.checkpoint import read_ckpt_lowmem -from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard +from mesh_transformer.transformer_shard import CausalTransformer, 
CausalTransformerShard, PlaceholderTensor +from mesh_transformer.util import to_bf16 params: Dict[str, Any] = {} @@ -776,7 +780,26 @@ def infer_static( return samples -def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None: +def reshard_reverse(x, total_shards, old_shape): + assert len(x.shape) != 1 + if len(x.shape) == 2: + if old_shape[1] == x.shape[1]: + out = x[0:1].tile((total_shards, 1)) + else: + out = x.reshape(old_shape) + elif len(x.shape) == 3: + if x.shape[0] * x.shape[2] == old_shape[2]: + out = x.reshape(old_shape) + elif x.shape[0] * x.shape[1] == old_shape[1]: + out = x.reshape((old_shape[1], old_shape[0], old_shape[2])).permute((1, 0, 2)) + else: + assert False + else: + assert False + return out + + +def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None: global thread_resources_env, seq, tokenizer, network, params default_params = { @@ -795,6 +818,53 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) "tokenizer": "gpt2", } params = kwargs + + # Try to convert HF config.json to MTJ config + if hf_checkpoint: + spec_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "maps", vars.model_type + ".json") + if not os.path.isfile(spec_path): + raise NotImplementedError(f"Unsupported model type {repr(vars.model_type)}") + with open(spec_path) as f: + lazy_load_spec = json.load(f) + + if "mtj_compat" in lazy_load_spec: + params["compat"] = lazy_load_spec["mtj_compat"] + if "mtj_pe" in lazy_load_spec: + params["pe"] = lazy_load_spec["mtj_pe"] + for k, v in lazy_load_spec.get("mtj_config_map", {}).items(): + if type(v) is not list: + params[k] = params[v] + continue + for i in range(len(v)): + if i == len(v) - 1: + params[k] = v[i] + elif v[i] in params: + params[k] = params[v[i]] + break + + params["n_vocab"] = params["vocab_size"] + + if "activation_function" in params: + params["activation"] = params["activation_function"] + + # Both the number of attention heads in the model and the embedding + # dimension of the model need to be divisible by the number of TPU cores + # that we use, and JAX also requires the number of TPU cores used to be + # an even number if we're using more than one core, so logically we try + # to pick the largest possible even number of TPU cores such that the + # number of attention heads and embedding dimension are both divisible + # by the number of TPU cores, and fall back to one core if an even + # number of TPU cores is not possible. + for c in (8, 6, 4, 2, 1): + if 0 == params["n_heads"] % c == params["d_model"] % c: + params["cores_per_replica"] = c + break + + # The vocabulary size of the model also has to be divisible by the + # number of TPU cores, so we pad the vocabulary with the minimum + # possible number of dummy tokens such that it's divisible. 
+ params["n_vocab_padding"] = -(params["n_vocab"] % -params["cores_per_replica"]) + if "compat" in params: default_params["compat"] = params["compat"] if default_params["compat"] == "fairseq_lm": @@ -804,10 +874,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) params[param] = default_params[param] # Load tokenizer - if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")): - raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'") - tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"]) - tokenizer = tokenizer_class.from_pretrained(params["tokenizer"]) + if not hf_checkpoint: + if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")): + raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'") + tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"]) + tokenizer = tokenizer_class.from_pretrained(params["tokenizer"]) # Disable JAX warnings about these two functions having been renamed jax.host_count = jax.process_count @@ -844,5 +915,147 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) path += "/" network = PenalizingCausalTransformer(params, dematerialized=True) - network.state = read_ckpt_lowmem(network.state, path, devices.shape[1]) + + if not hf_checkpoint: + network.state = read_ckpt_lowmem(network.state, path, devices.shape[1]) + network.state = network.move_xmap(network.state, np.zeros(cores_per_replica)) + return + + # Convert from HF checkpoint + + move_xmap = jax.experimental.maps.xmap( + fun=lambda x, _: to_bf16(x), + in_axes=(["shard", ...], ["batch", ...]), + out_axes=["shard", ...], + axis_resources={'shard': 'mp', 'batch': 'dp'} + ) + + model_spec = {} + for key, spec in lazy_load_spec.get("static_weights", {}).items(): + if spec.get("mtj") is not None: + model_spec[key] = spec["mtj"].copy() + model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"] + for _key, spec in lazy_load_spec.get("layer_weights", {}).items(): + for layer in range(params["layers"]): + if spec.get("mtj") is not None: + key = _key.format(layer=layer) + model_spec[key] = spec["mtj"].copy() + model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer) + + import torch_lazy_loader + import torch + from tqdm import tqdm + + def callback(model_dict, f, **_): + with zipfile.ZipFile(f, "r") as z: + try: + last_storage_key = None + f = None + print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n") + for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"): + + # Some model weights are used by transformers but not by MTJ. + # We have to materialize these weights anyways because + # transformers will throw a tantrum otherwise. To attain + # the least possible memory usage, we create them as meta + # tensors, which don't take up any actual CPU or TPU memory. 
+ if key not in model_spec: + model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].storage_type(0).dtype, device="meta") + continue + + storage_key = model_dict[key].key + if storage_key != last_storage_key: + last_storage_key = storage_key + if isinstance(f, zipfile.ZipExtFile): + f.close() + f = z.open(f"archive/data/{storage_key}") + current_offset = f.tell() + if current_offset != model_dict[key].seek_offset: + f.seek(model_dict[key].seek_offset - current_offset, 1) + spec = model_spec[key] + transforms = set(spec.get("transforms", ())) + if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor): + error = f"Duplicate key {repr(key)}" + print("\n\nERROR: " + error, file=sys.stderr) + raise RuntimeError(error) + tensor = model_dict[key].materialize(f, map_location="cpu") + model_dict[key] = tensor.to("meta") + + # MTJ requires certain mathematical operations to be performed + # on tensors in order for them to be in the correct format + if "divide_by_shards" in transforms: + tensor /= params["cores_per_replica"] + if "vocab_pad" in transforms: + tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"])) + if "no_transpose" not in transforms: + tensor = tensor.T + tensor.unsqueeze_(0) + + # Shard the tensor so that parts of the tensor can be used + # on different TPU cores + network.state["params"][spec["module"]][spec["param"]] = move_xmap( + jnp.array( + reshard_reverse( + tensor, + params["cores_per_replica"], + network.state["params"][spec["module"]][spec["param"]].shape, + ), + dtype=jnp.bfloat16, + ), + np.empty(params["cores_per_replica"]), + ) + + # Check for tensors that MTJ needs that were not provided in the + # HF model + for mk, mv in network.state["params"].items(): + for pk, pv in mv.items(): + if isinstance(pv, PlaceholderTensor): + # The transformers GPT-J models apparently do not + # have embedding bias, whereas MTJ GPT-J models do, + # so we have to supplement an embedding bias tensor + # by creating a tensor with the necessary shape, filled + # with zeros. 
+ if mk == "causal_transformer_shard/~/embedding_shard/~/linear" and pk == "b": + mv[pk] = move_xmap(jnp.zeros(mv[pk].shape, dtype=jnp.bfloat16), np.empty(params["cores_per_replica"])) + + else: + error = f"{mk} {pk} could not be found in the model checkpoint" + print("\n\nERROR: " + error, file=sys.stderr) + raise RuntimeError(error) + finally: + if isinstance(f, zipfile.ZipExtFile): + f.close() + + if os.path.isdir(vars.model.replace('/', '_')): + import shutil + shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_'))) + with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True): + if(os.path.isdir(vars.custmodpth)): + try: + tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache") + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache") + elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))): + try: + tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) + else: + try: + tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache") + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache") + network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))