Merge pull request #93 from VE-FORBRYDERNE/lazy-loader

Lazy loader
henk717 2022-03-05 20:32:31 +01:00 committed by GitHub
commit 77cc2ee789
6 changed files with 713 additions and 63 deletions

aiserver.py

@@ -26,6 +26,8 @@ import traceback
import threading
import markdown
import bleach
import itertools
import bisect
from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
@@ -250,6 +252,8 @@ class vars:
newlinemode = "n"
quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page)
debug = False # If set to true, will send debug information to the client for display
lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage
use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" # Whether or not we're in a Colab TPU instance and are going to use the TPU rather than the CPU
utils.vars = vars
@@ -337,10 +341,10 @@ def device_list(n_layers, primary=None, selected=None):
sep_color = colors.YELLOW
print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}")
def device_config(model):
def device_config(config):
global breakmodel, generator
import breakmodel
n_layers = model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer
n_layers = config.num_layers if hasattr(config, "num_layers") else config.n_layer
if(args.breakmodel_gpulayers is not None):
try:
breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
@@ -413,22 +417,30 @@ def device_config(model):
# If all layers are on the same device, use the old GPU generation mode
while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
breakmodel.gpu_blocks.pop()
if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer)):
if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, config.num_layers if hasattr(config, "num_layers") else config.n_layer)):
vars.breakmodel = False
vars.usegpu = True
vars.gpu_device = len(breakmodel.gpu_blocks)-1
model = model.half().to(vars.gpu_device)
generator = model.generate
return
if(not breakmodel.gpu_blocks):
print("Nothing assigned to a GPU, reverting to CPU only mode")
vars.breakmodel = False
vars.usegpu = False
model = model.to('cpu').float()
return
def move_model_to_devices(model):
global generator
if(not vars.breakmodel):
if(vars.usegpu):
model = model.half().to(vars.gpu_device)
else:
model = model.to('cpu').float()
generator = model.generate
return
model.half().to('cpu')
model.half()
gc.collect()
if(hasattr(model, "transformer")):
model.transformer.wte.to(breakmodel.primary_device)
@@ -684,7 +696,7 @@ def spRequest(filename):
vars.sp_length = tensor.shape[-2]
vars.spmeta["n_tokens"] = vars.sp_length
if(vars.model in ("TPUMeshTransformerGPTJ",)):
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
rows = tensor.shape[0]
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
@@ -756,6 +768,9 @@ if args.ngrok:
if args.host:
vars.host = True
if args.cpu:
vars.use_colab_tpu = False
vars.smandelete = vars.host == args.override_delete
vars.smanrename = vars.host == args.override_rename
@@ -772,7 +787,7 @@ else:
getModelSelection(mainmenu)
# If transformers model was selected & GPU available, ask to use CPU or GPU
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
vars.allowsp = True
# Test for GPU support
import torch
@@ -811,6 +826,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
elif(vars.model_type == "not_found"):
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
vars.model_type = "gpt_neo"
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
loadmodelsettings()
loadsettings()
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -1003,7 +1020,7 @@ socketio = SocketIO(app, async_method="eventlet")
print("{0}OK!{1}".format(colors.GREEN, colors.END))
# Start transformers and create pipeline
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
@@ -1015,6 +1032,71 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
import transformers.generation_utils
from transformers import __version__ as transformers_version
# Lazy loader
import torch_lazy_loader
def get_lazy_load_callback(n_layers, convert_to_float16=True):
if not vars.lazy_load:
return
from tqdm import tqdm
if "breakmodel" in globals():
gpu_blocks = breakmodel.gpu_blocks
ram_blocks = n_layers - sum(gpu_blocks)
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
else:
ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
def lazy_load_callback(model_dict, f, **_):
device_map = {}
for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
for layer in range(n_layers):
key = _key.format(layer=layer)
if key not in model_dict:
continue
device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel or layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
device_map[key] = device
for key, value in model_dict.items():
if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
with zipfile.ZipFile(f, "r") as z:
try:
last_storage_key = None
f = None
for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
storage_key = model_dict[key].key
if storage_key != last_storage_key:
last_storage_key = storage_key
if isinstance(f, zipfile.ZipExtFile):
f.close()
f = z.open(f"archive/data/{storage_key}")
current_offset = f.tell()
if current_offset != model_dict[key].seek_offset:
f.seek(model_dict[key].seek_offset - current_offset, 1)
device = device_map[key]
#print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
if convert_to_float16 and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32:
model_dict[key] = model_dict[key].to(torch.float16)
model_dict[key] = model_dict[key].to(device)
#print("OK", flush=True)
finally:
if isinstance(f, zipfile.ZipExtFile):
f.close()
return lazy_load_callback
lazy_load_config_path = os.path.join(path.dirname(path.realpath(__file__)), "maps", vars.model_type + ".json")
if(vars.lazy_load and "model_config" in globals() and os.path.isfile(lazy_load_config_path)):
with open(lazy_load_config_path) as f:
lazy_load_spec = json.load(f)
else:
vars.lazy_load = False
# Some versions of transformers 4.17.0.dev0 are affected by
# https://github.com/huggingface/transformers/issues/15736
# This is a workaround for those versions of transformers.
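A worked example (not part of the commit) of the bisect-based device assignment used in lazy_load_callback above, with a hypothetical breakmodel split:

import bisect
import itertools

# Hypothetical split: 10 layers total, 5 kept in RAM, 2 on GPU 0, 3 on GPU 1.
n_layers = 10
gpu_blocks = [2, 3]
ram_blocks = n_layers - sum(gpu_blocks)                          # 5
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))  # (2, 5)

for layer in range(n_layers):
    if layer < ram_blocks:
        device = "cpu"  # the lowest layers stay in RAM
    else:
        # index of the GPU whose block range contains this layer
        device = bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
    print(layer, device)
# layers 0-4 -> "cpu", layers 5-6 -> GPU 0, layers 7-9 -> GPU 1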
@@ -1250,6 +1332,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# If custom GPT2 model was chosen
if(vars.model == "GPT2Custom"):
vars.lazy_load = False
model_config = open(vars.custmodpth + "/config.json", "r")
js = json.load(model_config)
with(maybe_use_float16()):
@@ -1271,6 +1354,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# feature yet
if(vars.model_type == "gpt2"):
lowmem = {}
# If we're using torch_lazy_loader, we need to get breakmodel config
# early so that it knows where to load the individual model tensors
if(vars.lazy_load and vars.hascuda and vars.breakmodel):
device_config(model_config)
# Download model from Huggingface if it does not exist, otherwise load locally
@@ -1278,43 +1366,43 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
if os.path.isdir(vars.model.replace('/', '_')):
import shutil
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
if(os.path.isdir(vars.custmodpth)):
with(maybe_use_float16()):
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
with(maybe_use_float16()):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
with(maybe_use_float16()):
with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer), dematerialized_modules=True):
if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
lowmem = {}
if(os.path.isdir(vars.custmodpth)):
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)
if not args.colab:
import shutil
model = model.half()
model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
shutil.rmtree("cache/")
if not args.colab:
import shutil
model = model.half()
model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
shutil.rmtree("cache/")
if(vars.hascuda):
if(vars.usegpu):
@@ -1323,7 +1411,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
generator = model.generate
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
vars.modeldim = get_hidden_size_from_model(model)
device_config(model)
if(not vars.lazy_load):
device_config(model.config)
move_model_to_devices(model)
else:
model = model.to('cpu').float()
vars.modeldim = get_hidden_size_from_model(model)
@@ -1440,9 +1530,9 @@ else:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
loadsettings()
# Load the TPU backend if requested
elif(vars.model == "TPUMeshTransformerGPTJ"):
elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
if not vars.custmodpth or not os.path.isdir(vars.custmodpth):
if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
import tpu_mtj_backend
tpu_mtj_backend.vars = vars
@@ -1454,7 +1544,7 @@ else:
vars.allowsp = True
loadmodelsettings()
loadsettings()
tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model != "TPUMeshTransformerGPTJ" and vars.use_colab_tpu, **vars.modelconfig)
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
tokenizer = tpu_mtj_backend.tokenizer
else:
@@ -1985,7 +2075,7 @@ def lua_get_modeltype():
return "readonly"
if(vars.model in ("Colab", "OAI", "InferKit")):
return "api"
if(vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
hidden_size = get_hidden_size_from_model(model)
if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)):
return "gpt2"
@@ -2001,7 +2091,7 @@ def lua_get_modeltype():
return "gpt-neo-1.3B"
if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model_type == "gpt_neo" and hidden_size == 2560)):
return "gpt-neo-2.7B"
if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)):
if(vars.model in ("EleutherAI/gpt-j-6B",) or ((vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ") and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)):
return "gpt-j-6B"
return "unknown"
@@ -2014,7 +2104,7 @@ def lua_get_modelbackend():
return "readonly"
if(vars.model in ("Colab", "OAI", "InferKit")):
return "api"
if(vars.model in ("TPUMeshTransformerGPTJ",)):
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
return "mtj"
return "transformers"
@@ -2961,22 +3051,22 @@ def calcsubmit(txt):
if(vars.model != "InferKit"):
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
if(actionlen == 0):
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "OAI"):
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "TPUMeshTransformerGPTJ"):
elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
else:
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "OAI"):
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "TPUMeshTransformerGPTJ"):
elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
# For InferKit web API
@@ -4987,7 +5077,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
file.close()
# Precompile TPU backend if required
if(vars.model in ("TPUMeshTransformerGPTJ",)):
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
soft_tokens = tpumtjgetsofttokens()
if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
threading.Thread(

maps/gpt_neo.json Normal file (32 lines)

@@ -0,0 +1,32 @@
{
"mtj_compat": "neo",
"mtj_pe": "fixed",
"mtj_config_map": {
"d_model": "hidden_size",
"n_heads": "num_heads",
"layers": "num_layers"
},
"static_weights": {
"transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
"transformer.wpe.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose"]}},
"transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
"transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}
},
"layer_weights": {
"transformer.h.{layer}.attn.attention.bias": {},
"transformer.h.{layer}.attn.attention.masked_bias": {},
"transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
"transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
"transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
"transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
"transformer.h.{layer}.attn.attention.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
"transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
"transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
"transformer.h.{layer}.mlp.c_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
"transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}},
"transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}},
"transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}}
}
}
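A sketch (not part of the commit) of how the templated "layer_weights" entries in this map expand; the "causal_transformer_shard/~/" prefix is the one tpu_mtj_backend.py prepends when it builds its model_spec:

import json

with open("maps/gpt_neo.json") as f:
    spec = json.load(f)

for template, entry in spec["layer_weights"].items():
    mtj = entry.get("mtj")
    if mtj is None:
        continue  # entries with an empty spec exist in the HF checkpoint but are unused by MTJ
    key = template.format(layer=0)
    module = "causal_transformer_shard/~/" + mtj["module"].format(layer=0)
    print(f"{key} -> {module} / {mtj['param']}")
# e.g. transformer.h.0.attn.attention.q_proj.weight -> causal_transformer_shard/~/layer_0/~/linear / w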

maps/gptj.json Normal file (32 lines)

@@ -0,0 +1,32 @@
{
"mtj_compat": "j",
"mtj_pe": "rotary",
"mtj_config_map": {
"pe_rotary_dims": ["rotary_dim", 64],
"d_model": "n_embd",
"n_heads": "n_head",
"layers": "n_layer"
},
"static_weights": {
"transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
"transformer.wte.bias": {"mtj": {"module": "embedding_shard/~/linear", "param": "b"}},
"transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
"transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}},
"lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}},
"lm_head.bias": {"mtj": {"module": "projection_shard/~/linear", "param": "b"}}
},
"layer_weights": {
"transformer.h.{layer}.attn.bias": {},
"transformer.h.{layer}.attn.masked_bias": {},
"transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
"transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
"transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
"transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
"transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
"transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
"transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
"transformer.h.{layer}.mlp.fc_out.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
"transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}}
}
}

maps/xglm.json Normal file (33 lines)

@@ -0,0 +1,33 @@
{
"mtj_compat": "fairseq_lm",
"mtj_pe": "fairseq_sinusoidal",
"mtj_config_map": {
"d_model": "d_model",
"n_heads": "attention_heads",
"layers": "num_layers"
},
"static_weights": {
"model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
"model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
"model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}},
"lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}}
},
"layer_weights": {
"model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
"model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b"}},
"model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
"model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b"}},
"model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
"model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b"}},
"model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
"model.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
"model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
"model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
"model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
"model.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
"model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}},
"model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}},
"model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}}
}
}

torch_lazy_loader.py Normal file (250 lines)

@@ -0,0 +1,250 @@
'''
This file is AGPL-licensed.
Some of the code in this file is copied from PyTorch.
The license for PyTorch is shown below:
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
'''
import contextlib
from functools import reduce
import itertools
import zipfile
import pickle
import torch
from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
class LazyTensor:
def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
self.storage_type = storage_type
self.key = key
self.location = location
self.seek_offset = seek_offset
self.shape = shape
self.stride = stride
self.requires_grad = requires_grad
self.backward_hooks = backward_hooks
def __view(self, f: Callable):
return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
def __repr__(self):
return self.__view(repr)
def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor:
size = reduce(lambda x, y: x * y, self.shape, 1)
dtype = self.storage_type(0).dtype
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
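# (e.g. a (4096, 4096) float16 tensor occupies 4096 * 4096 * (16 >> 3) = 33,554,432 bytes)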
if isinstance(checkpoint, zipfile.ZipFile):
f = checkpoint.open(f"archive/data/{self.key}", "r")
f.seek(self.seek_offset)
else:
f = checkpoint
try:
storage = self.storage_type.from_buffer(f.read(nbytes), "little")
finally:
if isinstance(checkpoint, zipfile.ZipFile):
f.close()
storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
tensor.set_(storage, 0, self.shape, self.stride)
tensor.requires_grad = self.requires_grad
tensor._backward_hooks = self.backward_hooks
return tensor
class _LazyUnpickler(pickle.Unpickler):
lazy_loaded_storages: Dict[str, LazyTensor]
def __init__(self, *args, **kwargs):
self.lazy_loaded_storages = {}
return super().__init__(*args, **kwargs)
def forced_persistent_load(self, saved_id):
assert isinstance(saved_id, tuple)
typename = saved_id[0]
assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
storage_type, key, location, _ = saved_id[1:]
return LazyTensor(storage_type, key, location)
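# (shape, stride and seek_offset are still unknown at this point; they are
# filled in later by the patched torch._utils._rebuild_tensor below)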
def load(self, *args, **kwargs):
self.persistent_load = self.forced_persistent_load
retval = super().load(*args, **kwargs)
self.lazy_loaded_storages = {}
return retval
def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
lazy_storage.shape = shape
lazy_storage.stride = stride
dtype = lazy_storage.storage_type(0).dtype
lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
return lazy_storage
# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if not torch.overrides.is_tensor_like(input_param):
error_msgs.append('While copying the parameter named "{}", '
'expected torch.Tensor or Tensor-like object from checkpoint but '
'received {}'
.format(key, type(input_param)))
continue
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they don't have the hook to do the checks
# in such a case; it would error when accessing the .shape attribute.
is_param_lazy = torch.nn.parameter.is_lazy(param)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if not is_param_lazy and input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
continue
try:
with torch.no_grad():
#param.copy_(input_param)
new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad) # This line is new
if name in self._parameters: # This line is new
self._parameters[name] = new_param # This line is new
if name in persistent_buffers: # This line is new
self._buffers[name] = new_param # This line is new
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'
.format(key, param.size(), input_param.size(), ex.args))
elif strict:
missing_keys.append(key)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
missing_keys.append(extra_state_key)
elif strict and (extra_state_key in state_dict):
unexpected_keys.append(extra_state_key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)
@contextlib.contextmanager
def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
if not enable:
yield False
return
try:
old_unpickler = pickle.Unpickler
pickle.Unpickler = _LazyUnpickler
old_rebuild_tensor = torch._utils._rebuild_tensor
torch._utils._rebuild_tensor = _rebuild_tensor
old_torch_load = torch.load
def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
if callback is not None:
callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
return retval
torch.load = torch_load
if dematerialized_modules:
old_linear_init = torch.nn.Linear.__init__
old_embedding_init = torch.nn.Embedding.__init__
old_layernorm_init = torch.nn.LayerNorm.__init__
def linear_init(self, *args, device=None, **kwargs):
return old_linear_init(self, *args, device="meta", **kwargs)
def embedding_init(self, *args, device=None, **kwargs):
return old_embedding_init(self, *args, device="meta", **kwargs)
def layernorm_init(self, *args, device=None, **kwargs):
return old_layernorm_init(self, *args, device="meta", **kwargs)
torch.nn.Linear.__init__ = linear_init
torch.nn.Embedding.__init__ = embedding_init
torch.nn.LayerNorm.__init__ = layernorm_init
old_load_from_state_dict = torch.nn.Module._load_from_state_dict
torch.nn.Module._load_from_state_dict = _load_from_state_dict
yield True
finally:
pickle.Unpickler = old_unpickler
torch._utils._rebuild_tensor = old_rebuild_tensor
torch.load = old_torch_load
if dematerialized_modules:
torch.nn.Linear.__init__ = old_linear_init
torch.nn.Embedding.__init__ = old_embedding_init
torch.nn.LayerNorm.__init__ = old_layernorm_init
torch.nn.Module._load_from_state_dict = old_load_from_state_dict
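
A minimal usage sketch of this context manager on its own (hypothetical checkpoint path; not part of the commit). While the patches are active, torch.load returns LazyTensor placeholders instead of materialized tensors, and the callback decides what, if anything, to materialize:

import torch
import torch_lazy_loader

def inspect_callback(model_dict, f, **_):
    # Every value is still a LazyTensor placeholder here; only the pickle
    # metadata has been read from disk, not the tensor storages themselves.
    for key, value in model_dict.items():
        if isinstance(value, torch_lazy_loader.LazyTensor):
            print(key, value.shape)

with torch_lazy_loader.use_lazy_torch_load(callback=inspect_callback):
    state_dict = torch.load("pytorch_model.bin", map_location="cpu")  # hypothetical path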

tpu_mtj_backend.py

@@ -32,6 +32,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
import progressbar
import time
import os
import sys
import json
import zipfile
import requests
import random
import jax
@@ -41,9 +44,10 @@ import jax.numpy as jnp
import numpy as np
import optax
import haiku as hk
import transformers
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
from mesh_transformer.util import to_bf16
params: Dict[str, Any] = {}
@@ -776,7 +780,26 @@ def infer_static(
return samples
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
def reshard_reverse(x, total_shards, old_shape):
assert len(x.shape) != 1
if len(x.shape) == 2:
if old_shape[1] == x.shape[1]:
out = x[0:1].tile((total_shards, 1))
else:
out = x.reshape(old_shape)
elif len(x.shape) == 3:
if x.shape[0] * x.shape[2] == old_shape[2]:
out = x.reshape(old_shape)
elif x.shape[0] * x.shape[1] == old_shape[1]:
out = x.reshape((old_shape[1], old_shape[0], old_shape[2])).permute((1, 0, 2))
else:
assert False
else:
assert False
return out
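# Example (not part of the commit) of reshard_reverse's shape behaviour with
# total_shards=8 and x = some_weight.T.unsqueeze(0) of shape (1, 4096, 4096):
#   reshard_reverse(x, 8, (8, 512, 4096))  ->  x.reshape((8, 512, 4096))
#   reshard_reverse(x, 8, (8, 4096, 512))  ->  x.reshape((4096, 8, 512)).permute((1, 0, 2))
# 2-D inputs are either tiled across shards (replicated parameters) or reshaped.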
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params
default_params = {
@@ -795,6 +818,53 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
"tokenizer": "gpt2",
}
params = kwargs
# Try to convert HF config.json to MTJ config
if hf_checkpoint:
spec_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "maps", vars.model_type + ".json")
if not os.path.isfile(spec_path):
raise NotImplementedError(f"Unsupported model type {repr(vars.model_type)}")
with open(spec_path) as f:
lazy_load_spec = json.load(f)
if "mtj_compat" in lazy_load_spec:
params["compat"] = lazy_load_spec["mtj_compat"]
if "mtj_pe" in lazy_load_spec:
params["pe"] = lazy_load_spec["mtj_pe"]
for k, v in lazy_load_spec.get("mtj_config_map", {}).items():
if type(v) is not list:
params[k] = params[v]
continue
for i in range(len(v)):
if i == len(v) - 1:
params[k] = v[i]
elif v[i] in params:
params[k] = params[v[i]]
break
params["n_vocab"] = params["vocab_size"]
if "activation_function" in params:
params["activation"] = params["activation_function"]
# Both the number of attention heads in the model and the embedding
# dimension of the model need to be divisible by the number of TPU cores
# that we use, and JAX also requires the number of TPU cores used to be
# an even number if we're using more than one core, so logically we try
# to pick the largest possible even number of TPU cores such that the
# number of attention heads and embedding dimension are both divisible
# by the number of TPU cores, and fall back to one core if an even
# number of TPU cores is not possible.
for c in (8, 6, 4, 2, 1):
if 0 == params["n_heads"] % c == params["d_model"] % c:
params["cores_per_replica"] = c
break
# The vocabulary size of the model also has to be divisible by the
# number of TPU cores, so we pad the vocabulary with the minimum
# possible number of dummy tokens such that it's divisible.
params["n_vocab_padding"] = -(params["n_vocab"] % -params["cores_per_replica"])
if "compat" in params:
default_params["compat"] = params["compat"]
if default_params["compat"] == "fairseq_lm":
@@ -804,10 +874,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
params[param] = default_params[param]
# Load tokenizer
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
if not hf_checkpoint:
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
# Disable JAX warnings about these two functions having been renamed
jax.host_count = jax.process_count
@@ -844,5 +915,147 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
path += "/"
network = PenalizingCausalTransformer(params, dematerialized=True)
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
if not hf_checkpoint:
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
return
# Convert from HF checkpoint
move_xmap = jax.experimental.maps.xmap(
fun=lambda x, _: to_bf16(x),
in_axes=(["shard", ...], ["batch", ...]),
out_axes=["shard", ...],
axis_resources={'shard': 'mp', 'batch': 'dp'}
)
model_spec = {}
for key, spec in lazy_load_spec.get("static_weights", {}).items():
if spec.get("mtj") is not None:
model_spec[key] = spec["mtj"].copy()
model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"]
for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
for layer in range(params["layers"]):
if spec.get("mtj") is not None:
key = _key.format(layer=layer)
model_spec[key] = spec["mtj"].copy()
model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer)
import torch_lazy_loader
import torch
from tqdm import tqdm
def callback(model_dict, f, **_):
with zipfile.ZipFile(f, "r") as z:
try:
last_storage_key = None
f = None
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
# Some model weights are used by transformers but not by MTJ.
# We have to materialize these weights anyways because
# transformers will throw a tantrum otherwise. To attain
# the least possible memory usage, we create them as meta
# tensors, which don't take up any actual CPU or TPU memory.
if key not in model_spec:
model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].storage_type(0).dtype, device="meta")
continue
storage_key = model_dict[key].key
if storage_key != last_storage_key:
last_storage_key = storage_key
if isinstance(f, zipfile.ZipExtFile):
f.close()
f = z.open(f"archive/data/{storage_key}")
current_offset = f.tell()
if current_offset != model_dict[key].seek_offset:
f.seek(model_dict[key].seek_offset - current_offset, 1)
spec = model_spec[key]
transforms = set(spec.get("transforms", ()))
if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
error = f"Duplicate key {repr(key)}"
print("\n\nERROR: " + error, file=sys.stderr)
raise RuntimeError(error)
tensor = model_dict[key].materialize(f, map_location="cpu")
model_dict[key] = tensor.to("meta")
# MTJ requires certain mathematical operations to be performed
# on tensors in order for them to be in the correct format
if "divide_by_shards" in transforms:
tensor /= params["cores_per_replica"]
if "vocab_pad" in transforms:
tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"]))
if "no_transpose" not in transforms:
tensor = tensor.T
tensor.unsqueeze_(0)
# Shard the tensor so that parts of the tensor can be used
# on different TPU cores
network.state["params"][spec["module"]][spec["param"]] = move_xmap(
jnp.array(
reshard_reverse(
tensor,
params["cores_per_replica"],
network.state["params"][spec["module"]][spec["param"]].shape,
),
dtype=jnp.bfloat16,
),
np.empty(params["cores_per_replica"]),
)
# Check for tensors that MTJ needs that were not provided in the
# HF model
for mk, mv in network.state["params"].items():
for pk, pv in mv.items():
if isinstance(pv, PlaceholderTensor):
# The transformers GPT-J models apparently do not
# have embedding bias, whereas MTJ GPT-J models do,
# so we have to supplement an embedding bias tensor
# by creating a tensor with the necessary shape, filled
# with zeros.
if mk == "causal_transformer_shard/~/embedding_shard/~/linear" and pk == "b":
mv[pk] = move_xmap(jnp.zeros(mv[pk].shape, dtype=jnp.bfloat16), np.empty(params["cores_per_replica"]))
else:
error = f"{mk} {pk} could not be found in the model checkpoint"
print("\n\nERROR: " + error, file=sys.stderr)
raise RuntimeError(error)
finally:
if isinstance(f, zipfile.ZipExtFile):
f.close()
if os.path.isdir(vars.model.replace('/', '_')):
import shutil
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
if(os.path.isdir(vars.custmodpth)):
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
else:
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache")
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
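
To make the weight transforms in the callback above concrete, a standalone sketch with hypothetical shapes (not the commit's code):

import torch

cores_per_replica, n_vocab_padding = 8, 7

# "vocab_pad": append dummy vocab rows so the vocab dimension divides evenly
# across the TPU cores (50257 + 7 = 50264, which is divisible by 8).
wte = torch.zeros(50257, 4096)  # marked "no_transpose", so it keeps (vocab, d_model)
wte = torch.nn.functional.pad(wte, (0, 0, 0, n_vocab_padding))
assert wte.shape == (50264, 4096)

# "divide_by_shards": the bias is replicated on every shard, presumably so the
# per-shard contributions sum back to the original value after pre-dividing.
bias = torch.zeros(4096) / cores_per_replica

# Default path: transpose (HF Linear weights are (out, in); MTJ wants (in, out))
# and add a leading axis for reshard_reverse.
w = torch.zeros(4096, 16384).T.unsqueeze(0)
assert w.shape == (1, 16384, 4096)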