commit 77cc2ee789

aiserver.py
@@ -26,6 +26,8 @@ import traceback
 import threading
 import markdown
 import bleach
+import itertools
+import bisect
 from collections.abc import Iterable
 from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
@@ -250,6 +252,8 @@ class vars:
     newlinemode = "n"
     quiet       = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page)
     debug       = False # If set to true, will send debug information to the client for display
+    lazy_load   = True  # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage
+    use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != ""  # Whether or not we're in a Colab TPU instance and are going to use the TPU rather than the CPU

 utils.vars = vars
@@ -337,10 +341,10 @@ def device_list(n_layers, primary=None, selected=None):
     sep_color = colors.YELLOW
     print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}")

-def device_config(model):
+def device_config(config):
     global breakmodel, generator
     import breakmodel
-    n_layers = model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer
+    n_layers = config.num_layers if hasattr(config, "num_layers") else config.n_layer
     if(args.breakmodel_gpulayers is not None):
         try:
             breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
@@ -413,22 +417,30 @@ def device_config(model):
     # If all layers are on the same device, use the old GPU generation mode
     while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
         breakmodel.gpu_blocks.pop()
-    if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer)):
+    if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, config.num_layers if hasattr(config, "num_layers") else config.n_layer)):
         vars.breakmodel = False
         vars.usegpu = True
         vars.gpu_device = len(breakmodel.gpu_blocks)-1
-        model = model.half().to(vars.gpu_device)
-        generator = model.generate
         return

     if(not breakmodel.gpu_blocks):
         print("Nothing assigned to a GPU, reverting to CPU only mode")
         vars.breakmodel = False
         vars.usegpu = False
-        model = model.to('cpu').float()
         return

-    model.half().to('cpu')
+def move_model_to_devices(model):
+    global generator
+
+    if(not vars.breakmodel):
+        if(vars.usegpu):
+            model = model.half().to(vars.gpu_device)
+        else:
+            model = model.to('cpu').float()
+        generator = model.generate
+        return
+
+    model.half()
     gc.collect()
     if(hasattr(model, "transformer")):
         model.transformer.wte.to(breakmodel.primary_device)
@@ -684,7 +696,7 @@ def spRequest(filename):
     vars.sp_length = tensor.shape[-2]
     vars.spmeta["n_tokens"] = vars.sp_length

-    if(vars.model in ("TPUMeshTransformerGPTJ",)):
+    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
         rows = tensor.shape[0]
         padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
         tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
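A note on the padding arithmetic above: Python's modulo takes the sign of the divisor, so `x - (x % -y)` is `x` rounded up to a multiple of `y`. A minimal sketch with made-up numbers:

    # Hypothetical values for illustration only
    seq, cores_per_replica, rows = 2048, 8, 5
    padding_amount = seq - (seq % -cores_per_replica) - rows   # 2048 - 0 - 5 = 2043
    # rows + padding_amount == 2048: the padded soft prompt fills seq rows,
    # rounded up to a multiple of cores_per_replica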
@@ -756,6 +768,9 @@ if args.ngrok:
 if args.host:
     vars.host = True;

+if args.cpu:
+    vars.use_colab_tpu = False
+
 vars.smandelete = vars.host == args.override_delete
 vars.smanrename = vars.host == args.override_rename
@@ -772,7 +787,7 @@ else:
     getModelSelection(mainmenu)

 # If transformers model was selected & GPU available, ask to use CPU or GPU
-if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
+if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     vars.allowsp = True
     # Test for GPU support
     import torch
@@ -811,6 +826,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     elif(vars.model_type == "not_found"):
         print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
         vars.model_type = "gpt_neo"

+if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
+    loadmodelsettings()
+    loadsettings()
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -1003,7 +1020,7 @@ socketio = SocketIO(app, async_method="eventlet")
 print("{0}OK!{1}".format(colors.GREEN, colors.END))

 # Start transformers and create pipeline
-if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
+if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
         from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
@@ -1015,6 +1032,71 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
         import transformers.generation_utils
         from transformers import __version__ as transformers_version

+        # Lazy loader
+        import torch_lazy_loader
+        def get_lazy_load_callback(n_layers, convert_to_float16=True):
+            if not vars.lazy_load:
+                return
+
+            from tqdm import tqdm
+
+            if "breakmodel" in globals():
+                gpu_blocks = breakmodel.gpu_blocks
+                ram_blocks = n_layers - sum(gpu_blocks)
+                cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
+            else:
+                ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
+
+            def lazy_load_callback(model_dict, f, **_):
+                device_map = {}
+
+                for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
+                    for layer in range(n_layers):
+                        key = _key.format(layer=layer)
+                        if key not in model_dict:
+                            continue
+                        device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel or layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
+                        device_map[key] = device
+
+                for key, value in model_dict.items():
+                    if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
+                        device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
+
+                with zipfile.ZipFile(f, "r") as z:
+                    try:
+                        last_storage_key = None
+                        f = None
+                        for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+                            storage_key = model_dict[key].key
+                            if storage_key != last_storage_key:
+                                last_storage_key = storage_key
+                                if isinstance(f, zipfile.ZipExtFile):
+                                    f.close()
+                                f = z.open(f"archive/data/{storage_key}")
+                            current_offset = f.tell()
+                            if current_offset != model_dict[key].seek_offset:
+                                f.seek(model_dict[key].seek_offset - current_offset, 1)
+                            device = device_map[key]
+                            #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
+                            model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
+                            if convert_to_float16 and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32:
+                                model_dict[key] = model_dict[key].to(torch.float16)
+                            model_dict[key] = model_dict[key].to(device)
+                            #print("OK", flush=True)
+                    finally:
+                        if isinstance(f, zipfile.ZipExtFile):
+                            f.close()
+
+            return lazy_load_callback
+
+        lazy_load_config_path = os.path.join(path.dirname(path.realpath(__file__)), "maps", vars.model_type + ".json")
+        if(vars.lazy_load and "model_config" in globals() and os.path.isfile(lazy_load_config_path)):
+            with open(lazy_load_config_path) as f:
+                lazy_load_spec = json.load(f)
+        else:
+            vars.lazy_load = False
+
         # Some versions of transformers 4.17.0.dev0 are affected by
         # https://github.com/huggingface/transformers/issues/15736
         # This is a workaround for those versions of transformers.
@@ -1250,6 +1332,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):

         # If custom GPT2 model was chosen
         if(vars.model == "GPT2Custom"):
+            vars.lazy_load = False
             model_config = open(vars.custmodpth + "/config.json", "r")
             js = json.load(model_config)
             with(maybe_use_float16()):
@@ -1272,49 +1355,54 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
         if(vars.model_type == "gpt2"):
             lowmem = {}

+        # If we're using torch_lazy_loader, we need to get breakmodel config
+        # early so that it knows where to load the individual model tensors
+        if(vars.lazy_load and vars.hascuda and vars.breakmodel):
+            device_config(model_config)
+
         # Download model from Huggingface if it does not exist, otherwise load locally

         #If we specify a model and it's in the root directory, we need to move it to the models directory (legacy folder structure to new)
         if os.path.isdir(vars.model.replace('/', '_')):
             import shutil
             shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
-        if(os.path.isdir(vars.custmodpth)):
-            with(maybe_use_float16()):
-                try:
-                    tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
-                except ValueError as e:
-                    tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
-                try:
-                    model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
-                except ValueError as e:
-                    model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
-        elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
-            with(maybe_use_float16()):
-                try:
-                    tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
-                except ValueError as e:
-                    tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
-                try:
-                    model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
-                except ValueError as e:
-                    model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
-        else:
-            try:
-                tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
-            except ValueError as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
-            with(maybe_use_float16()):
+        with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer), dematerialized_modules=True):
+            if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
+                lowmem = {}
+            if(os.path.isdir(vars.custmodpth)):
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
+                except ValueError as e:
+                    tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
+                try:
+                    model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
+                except ValueError as e:
+                    model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
+            elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+                except ValueError as e:
+                    tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+                try:
+                    model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
+                except ValueError as e:
+                    model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
+            else:
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
+                except ValueError as e:
+                    tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
+                try:
+                    model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)
+                except ValueError as e:
+                    model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)

-        if not args.colab:
-            import shutil
-            model = model.half()
-            model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
-            tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
-            shutil.rmtree("cache/")
+            if not args.colab:
+                import shutil
+                model = model.half()
+                model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
+                tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
+                shutil.rmtree("cache/")

         if(vars.hascuda):
             if(vars.usegpu):
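The combined with statement above is the heart of the change. A sketch of how the pieces compose (the layer count here is hypothetical; in the code it is read from model_config):

    # get_lazy_load_callback returns None when vars.lazy_load is off, and
    # use_lazy_torch_load(enable=False, ...) leaves torch.load untouched,
    # so the non-lazy path behaves exactly as before.
    callback = get_lazy_load_callback(n_layers=28)
    with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=callback, dematerialized_modules=True):
        model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache")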
@@ -1323,7 +1411,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
                 generator = model.generate
             elif(vars.breakmodel):  # Use both RAM and VRAM (breakmodel)
                 vars.modeldim = get_hidden_size_from_model(model)
-                device_config(model)
+                if(not vars.lazy_load):
+                    device_config(model.config)
+                move_model_to_devices(model)
             else:
                 model = model.to('cpu').float()
                 vars.modeldim = get_hidden_size_from_model(model)
@@ -1440,9 +1530,9 @@ else:
     tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
     loadsettings()
 # Load the TPU backend if requested
-elif(vars.model == "TPUMeshTransformerGPTJ"):
+elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
     print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
-    if not vars.custmodpth or not os.path.isdir(vars.custmodpth):
+    if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
         raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
     import tpu_mtj_backend
     tpu_mtj_backend.vars = vars
@@ -1454,7 +1544,7 @@ else:
     vars.allowsp = True
     loadmodelsettings()
     loadsettings()
-    tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
+    tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model != "TPUMeshTransformerGPTJ" and vars.use_colab_tpu, **vars.modelconfig)
    vars.modeldim = int(tpu_mtj_backend.params["d_model"])
     tokenizer = tpu_mtj_backend.tokenizer
 else:
@@ -1985,7 +2075,7 @@ def lua_get_modeltype():
         return "readonly"
     if(vars.model in ("Colab", "OAI", "InferKit")):
         return "api"
-    if(vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
+    if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
         hidden_size = get_hidden_size_from_model(model)
     if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)):
         return "gpt2"
@@ -2001,7 +2091,7 @@ def lua_get_modeltype():
         return "gpt-neo-1.3B"
     if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model_type == "gpt_neo" and hidden_size == 2560)):
         return "gpt-neo-2.7B"
-    if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)):
+    if(vars.model in ("EleutherAI/gpt-j-6B",) or ((vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ") and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)):
         return "gpt-j-6B"
     return "unknown"
@@ -2014,7 +2104,7 @@ def lua_get_modelbackend():
         return "readonly"
     if(vars.model in ("Colab", "OAI", "InferKit")):
         return "api"
-    if(vars.model in ("TPUMeshTransformerGPTJ",)):
+    if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
         return "mtj"
     return "transformers"
@@ -2961,22 +3051,22 @@ def calcsubmit(txt):
     if(vars.model != "InferKit"):
         subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
         if(actionlen == 0):
-            if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
+            if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
                 sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
             elif(vars.model == "OAI"):
                 oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(vars.model == "TPUMeshTransformerGPTJ"):
+            elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
                 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
         else:
-            if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
+            if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
                 sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
             elif(vars.model == "OAI"):
                 oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(vars.model == "TPUMeshTransformerGPTJ"):
+            elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"):
                 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)

     # For InferKit web API
@@ -4987,7 +5077,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
     file.close()

 # Precompile TPU backend if required
-if(vars.model in ("TPUMeshTransformerGPTJ",)):
+if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)):
     soft_tokens = tpumtjgetsofttokens()
     if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
         threading.Thread(
maps/gpt_neo.json (new file)
@@ -0,0 +1,32 @@
+{
+    "mtj_compat": "neo",
+    "mtj_pe": "fixed",
+    "mtj_config_map": {
+        "d_model": "hidden_size",
+        "n_heads": "num_heads",
+        "layers": "num_layers"
+    },
+    "static_weights": {
+        "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
+        "transformer.wpe.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose"]}},
+        "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
+        "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}
+    },
+    "layer_weights": {
+        "transformer.h.{layer}.attn.attention.bias": {},
+        "transformer.h.{layer}.attn.attention.masked_bias": {},
+        "transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
+        "transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
+        "transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
+        "transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
+        "transformer.h.{layer}.attn.attention.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
+        "transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
+        "transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
+        "transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
+        "transformer.h.{layer}.mlp.c_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
+        "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
+        "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}},
+        "transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}},
+        "transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}}
+    }
+}
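The keys in `layer_weights` above are templates: `{layer}` is substituted with each layer index when the map is consumed, both by `get_lazy_load_callback` in aiserver.py and by the HF-to-MTJ conversion further down. A sketch, assuming a hypothetical 28-layer model:

    template = "transformer.h.{layer}.attn.attention.q_proj.weight"
    keys = [template.format(layer=i) for i in range(28)]
    # keys[0] == "transformer.h.0.attn.attention.q_proj.weight", and so on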
maps/gptj.json (new file)
@@ -0,0 +1,32 @@
+{
+    "mtj_compat": "j",
+    "mtj_pe": "rotary",
+    "mtj_config_map": {
+        "pe_rotary_dims": ["rotary_dim", 64],
+        "d_model": "n_embd",
+        "n_heads": "n_head",
+        "layers": "n_layer"
+    },
+    "static_weights": {
+        "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
+        "transformer.wte.bias": {"mtj": {"module": "embedding_shard/~/linear", "param": "b"}},
+        "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
+        "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}},
+        "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}},
+        "lm_head.bias": {"mtj": {"module": "projection_shard/~/linear", "param": "b"}}
+    },
+    "layer_weights": {
+        "transformer.h.{layer}.attn.bias": {},
+        "transformer.h.{layer}.attn.masked_bias": {},
+        "transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
+        "transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
+        "transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
+        "transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
+        "transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
+        "transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
+        "transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
+        "transformer.h.{layer}.mlp.fc_out.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
+        "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
+        "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}}
+    }
+}
(new file: fairseq_lm weight map)
@@ -0,0 +1,33 @@
+{
+    "mtj_compat": "fairseq_lm",
+    "mtj_pe": "fairseq_sinusoidal",
+    "mtj_config_map": {
+        "d_model": "d_model",
+        "n_heads": "attention_heads",
+        "layers": "num_layers"
+    },
+    "static_weights": {
+        "model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
+        "model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
+        "model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}},
+        "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}}
+    },
+    "layer_weights": {
+        "model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}},
+        "model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b"}},
+        "model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}},
+        "model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b"}},
+        "model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}},
+        "model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b"}},
+        "model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
+        "model.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
+        "model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
+        "model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
+        "model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
+        "model.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
+        "model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
+        "model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}},
+        "model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}},
+        "model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}}
+    }
+}
torch_lazy_loader.py (new file)
@@ -0,0 +1,250 @@
+'''
+This file is AGPL-licensed.
+
+Some of the code in this file is copied from PyTorch.
+
+The license for PyTorch is shown below:
+
+Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+Copyright (c) 2011-2013 NYU (Clement Farabet)
+Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+   and IDIAP Research Institute nor the names of its contributors may be
+   used to endorse or promote products derived from this software without
+   specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
+'''
+
+
+import contextlib
+from functools import reduce
+import itertools
+import zipfile
+import pickle
+import torch
+from torch.nn import Module
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+
+
+_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+
+
+class LazyTensor:
+    def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
+        self.storage_type = storage_type
+        self.key = key
+        self.location = location
+        self.seek_offset = seek_offset
+        self.shape = shape
+        self.stride = stride
+        self.requires_grad = requires_grad
+        self.backward_hooks = backward_hooks
+
+    def __view(self, f: Callable):
+        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
+
+    def __repr__(self):
+        return self.__view(repr)
+
+    def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor:
+        size = reduce(lambda x, y: x * y, self.shape, 1)
+        dtype = self.storage_type(0).dtype
+        nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
+        if isinstance(checkpoint, zipfile.ZipFile):
+            f = checkpoint.open(f"archive/data/{self.key}", "r")
+            f.seek(self.seek_offset)
+        else:
+            f = checkpoint
+        try:
+            storage = self.storage_type.from_buffer(f.read(nbytes), "little")
+        finally:
+            if isinstance(checkpoint, zipfile.ZipFile):
+                f.close()
+        storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
+        tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
+        tensor.set_(storage, 0, self.shape, self.stride)
+        tensor.requires_grad = self.requires_grad
+        tensor._backward_hooks = self.backward_hooks
+        return tensor
+
+
+class _LazyUnpickler(pickle.Unpickler):
+    lazy_loaded_storages: Dict[str, LazyTensor]
+
+    def __init__(self, *args, **kwargs):
+        self.lazy_loaded_storages = {}
+        return super().__init__(*args, **kwargs)
+
+    def forced_persistent_load(self, saved_id):
+        assert isinstance(saved_id, tuple)
+        typename = saved_id[0]
+        assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
+        storage_type, key, location, _ = saved_id[1:]
+        return LazyTensor(storage_type, key, location)
+
+    def load(self, *args, **kwargs):
+        self.persistent_load = self.forced_persistent_load
+        retval = super().load(*args, **kwargs)
+        self.lazy_loaded_storages = {}
+        return retval
+
+
+def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
+    lazy_storage.shape = shape
+    lazy_storage.stride = stride
+    dtype = lazy_storage.storage_type(0).dtype
+    lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
+    return lazy_storage
+
+
+# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
+def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    for hook in self._load_state_dict_pre_hooks.values():
+        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+    persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+    local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+    local_state = {k: v for k, v in local_name_params if v is not None}
+
+    for name, param in local_state.items():
+        key = prefix + name
+        if key in state_dict:
+            input_param = state_dict[key]
+            if not torch.overrides.is_tensor_like(input_param):
+                error_msgs.append('While copying the parameter named "{}", '
+                                  'expected torch.Tensor or Tensor-like object from checkpoint but '
+                                  'received {}'
+                                  .format(key, type(input_param)))
+                continue
+
+            # This is used to avoid copying uninitialized parameters into
+            # non-lazy modules, since they dont have the hook to do the checks
+            # in such case, it will error when accessing the .shape attribute.
+            is_param_lazy = torch.nn.parameter.is_lazy(param)
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+                input_param = input_param[0]
+
+            if not is_param_lazy and input_param.shape != param.shape:
+                # local shape should match the one in checkpoint
+                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+                                  'the shape in current model is {}.'
+                                  .format(key, input_param.shape, param.shape))
+                continue
+            try:
+                with torch.no_grad():
+                    #param.copy_(input_param)
+                    new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad)  # This line is new
+                    if name in self._parameters:  # This line is new
+                        self._parameters[name] = new_param  # This line is new
+                    if name in persistent_buffers:  # This line is new
+                        self._buffers[name] = new_param  # This line is new
+            except Exception as ex:
+                error_msgs.append('While copying the parameter named "{}", '
+                                  'whose dimensions in the model are {} and '
+                                  'whose dimensions in the checkpoint are {}, '
+                                  'an exception occurred : {}.'
+                                  .format(key, param.size(), input_param.size(), ex.args))
+        elif strict:
+            missing_keys.append(key)
+
+    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+    if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+        if extra_state_key in state_dict:
+            self.set_extra_state(state_dict[extra_state_key])
+        elif strict:
+            missing_keys.append(extra_state_key)
+    elif strict and (extra_state_key in state_dict):
+        unexpected_keys.append(extra_state_key)
+
+    if strict:
+        for key in state_dict.keys():
+            if key.startswith(prefix) and key != extra_state_key:
+                input_name = key[len(prefix):]
+                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
+                if input_name not in self._modules and input_name not in local_state:
+                    unexpected_keys.append(key)
+
+
+@contextlib.contextmanager
+def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
+    if not enable:
+        yield False
+        return
+
+    try:
+        old_unpickler = pickle.Unpickler
+        pickle.Unpickler = _LazyUnpickler
+
+        old_rebuild_tensor = torch._utils._rebuild_tensor
+        torch._utils._rebuild_tensor = _rebuild_tensor
+
+        old_torch_load = torch.load
+
+        def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
+            retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
+            if callback is not None:
+                callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
+            return retval
+
+        torch.load = torch_load
+
+        if dematerialized_modules:
+            old_linear_init = torch.nn.Linear.__init__
+            old_embedding_init = torch.nn.Embedding.__init__
+            old_layernorm_init = torch.nn.LayerNorm.__init__
+
+            def linear_init(self, *args, device=None, **kwargs):
+                return old_linear_init(self, *args, device="meta", **kwargs)
+
+            def embedding_init(self, *args, device=None, **kwargs):
+                return old_embedding_init(self, *args, device="meta", **kwargs)
+
+            def layernorm_init(self, *args, device=None, **kwargs):
+                return old_layernorm_init(self, *args, device="meta", **kwargs)
+
+            torch.nn.Linear.__init__ = linear_init
+            torch.nn.Embedding.__init__ = embedding_init
+            torch.nn.LayerNorm.__init__ = layernorm_init
+            old_load_from_state_dict = torch.nn.Module._load_from_state_dict
+            torch.nn.Module._load_from_state_dict = _load_from_state_dict
+
+        yield True
+
+    finally:
+        pickle.Unpickler = old_unpickler
+        torch._utils._rebuild_tensor = old_rebuild_tensor
+        torch.load = old_torch_load
+        if dematerialized_modules:
+            torch.nn.Linear.__init__ = old_linear_init
+            torch.nn.Embedding.__init__ = old_embedding_init
+            torch.nn.LayerNorm.__init__ = old_layernorm_init
+            torch.nn.Module._load_from_state_dict = old_load_from_state_dict
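A sketch of what the context manager above changes about torch.load (the checkpoint path here is hypothetical): inside the context, the returned state dict holds LazyTensor placeholders that record where each tensor lives in the checkpoint's zip archive rather than reading it, and the optional callback gets a chance to materialize them selectively:

    import torch
    import torch_lazy_loader

    def inspect(model_dict, f, **_):
        # At this point the tensor entries are still unread placeholders
        for name, value in model_dict.items():
            assert isinstance(value, torch_lazy_loader.LazyTensor)

    with torch_lazy_loader.use_lazy_torch_load(callback=inspect):
        state = torch.load("checkpoint.pt")  # hypothetical file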
tpu_mtj_backend.py
@@ -32,6 +32,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 import progressbar
 import time
 import os
+import sys
+import json
+import zipfile
 import requests
 import random
 import jax
@@ -41,9 +44,10 @@ import jax.numpy as jnp
 import numpy as np
 import optax
 import haiku as hk
 import transformers
+from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
 from mesh_transformer.checkpoint import read_ckpt_lowmem
-from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard
+from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
 from mesh_transformer.util import to_bf16

 params: Dict[str, Any] = {}
@@ -776,7 +780,26 @@ def infer_static(
     return samples


-def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
+def reshard_reverse(x, total_shards, old_shape):
+    assert len(x.shape) != 1
+    if len(x.shape) == 2:
+        if old_shape[1] == x.shape[1]:
+            out = x[0:1].tile((total_shards, 1))
+        else:
+            out = x.reshape(old_shape)
+    elif len(x.shape) == 3:
+        if x.shape[0] * x.shape[2] == old_shape[2]:
+            out = x.reshape(old_shape)
+        elif x.shape[0] * x.shape[1] == old_shape[1]:
+            out = x.reshape((old_shape[1], old_shape[0], old_shape[2])).permute((1, 0, 2))
+        else:
+            assert False
+    else:
+        assert False
+    return out
+
+
+def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params

     default_params = {
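`reshard_reverse` maps an unsharded HF tensor onto the per-shard parameter layout MTJ expects; the name reflects that it inverts the usual shard-merging direction. Illustrative shapes only (8 shards assumed; the tensors arrive with a leading singleton axis, as added by the `unsqueeze_(0)` call further down):

    import torch
    b = torch.zeros(1, 4096)         # replicated parameter
    assert reshard_reverse(b, 8, (8, 4096)).shape == (8, 4096)   # tiled to every shard
    w = torch.zeros(1, 4096, 16384)  # transposed linear weight
    assert reshard_reverse(w, 8, (8, 512, 16384)).shape == (8, 512, 16384)   # split along dim 1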
@@ -795,6 +818,53 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         "tokenizer": "gpt2",
     }
     params = kwargs
+
+    # Try to convert HF config.json to MTJ config
+    if hf_checkpoint:
+        spec_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "maps", vars.model_type + ".json")
+        if not os.path.isfile(spec_path):
+            raise NotImplementedError(f"Unsupported model type {repr(vars.model_type)}")
+        with open(spec_path) as f:
+            lazy_load_spec = json.load(f)
+
+        if "mtj_compat" in lazy_load_spec:
+            params["compat"] = lazy_load_spec["mtj_compat"]
+        if "mtj_pe" in lazy_load_spec:
+            params["pe"] = lazy_load_spec["mtj_pe"]
+        # In the config map, a plain string is a config key to copy; a list is
+        # tried as config keys in order, with the final entry used as a
+        # literal fallback value
+        for k, v in lazy_load_spec.get("mtj_config_map", {}).items():
+            if type(v) is not list:
+                params[k] = params[v]
+                continue
+            for i in range(len(v)):
+                if i == len(v) - 1:
+                    params[k] = v[i]
+                elif v[i] in params:
+                    params[k] = params[v[i]]
+                    break
+
+        params["n_vocab"] = params["vocab_size"]
+
+        if "activation_function" in params:
+            params["activation"] = params["activation_function"]
+
+        # Both the number of attention heads in the model and the embedding
+        # dimension of the model need to be divisible by the number of TPU cores
+        # that we use, and JAX also requires the number of TPU cores used to be
+        # an even number if we're using more than one core, so logically we try
+        # to pick the largest possible even number of TPU cores such that the
+        # number of attention heads and embedding dimension are both divisible
+        # by the number of TPU cores, and fall back to one core if an even
+        # number of TPU cores is not possible.
+        for c in (8, 6, 4, 2, 1):
+            if 0 == params["n_heads"] % c == params["d_model"] % c:
+                params["cores_per_replica"] = c
+                break
+
+        # The vocabulary size of the model also has to be divisible by the
+        # number of TPU cores, so we pad the vocabulary with the minimum
+        # possible number of dummy tokens such that it's divisible.
+        params["n_vocab_padding"] = -(params["n_vocab"] % -params["cores_per_replica"])

     if "compat" in params:
         default_params["compat"] = params["compat"]
         if default_params["compat"] == "fairseq_lm":
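A worked example of the two computations above, using hypothetical dimensions (16 heads, d_model 4096, and GPT-2's 50257-token vocabulary):

    n_heads, d_model, n_vocab = 16, 4096, 50257
    for c in (8, 6, 4, 2, 1):
        if 0 == n_heads % c == d_model % c:
            cores_per_replica = c   # picks 8: both 16 and 4096 are divisible by 8
            break
    n_vocab_padding = -(n_vocab % -cores_per_replica)   # 7, since 50257 + 7 == 50264 == 8 * 6283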
@@ -804,10 +874,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
             params[param] = default_params[param]

     # Load tokenizer
-    if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
-        raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
-    tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
-    tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
+    if not hf_checkpoint:
+        if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
+            raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
+        tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
+        tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])

     # Disable JAX warnings about these two functions having been renamed
     jax.host_count = jax.process_count
@@ -844,5 +915,147 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         path += "/"

     network = PenalizingCausalTransformer(params, dematerialized=True)
-    network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
-    network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+
+    if not hf_checkpoint:
+        network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
+        network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+        return
+
+    # Convert from HF checkpoint
+
+    move_xmap = jax.experimental.maps.xmap(
+        fun=lambda x, _: to_bf16(x),
+        in_axes=(["shard", ...], ["batch", ...]),
+        out_axes=["shard", ...],
+        axis_resources={'shard': 'mp', 'batch': 'dp'}
+    )
+
+    model_spec = {}
+    for key, spec in lazy_load_spec.get("static_weights", {}).items():
+        if spec.get("mtj") is not None:
+            model_spec[key] = spec["mtj"].copy()
+            model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"]
+    for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
+        for layer in range(params["layers"]):
+            if spec.get("mtj") is not None:
+                key = _key.format(layer=layer)
+                model_spec[key] = spec["mtj"].copy()
+                model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer)
+
+    import torch_lazy_loader
+    import torch
+    from tqdm import tqdm
+
+    def callback(model_dict, f, **_):
+        with zipfile.ZipFile(f, "r") as z:
+            try:
+                last_storage_key = None
+                f = None
+                print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
+                for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+
+                    # Some model weights are used by transformers but not by MTJ.
+                    # We have to materialize these weights anyways because
+                    # transformers will throw a tantrum otherwise. To attain
+                    # the least possible memory usage, we create them as meta
+                    # tensors, which don't take up any actual CPU or TPU memory.
+                    if key not in model_spec:
+                        model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].storage_type(0).dtype, device="meta")
+                        continue
+
+                    storage_key = model_dict[key].key
+                    if storage_key != last_storage_key:
+                        last_storage_key = storage_key
+                        if isinstance(f, zipfile.ZipExtFile):
+                            f.close()
+                        f = z.open(f"archive/data/{storage_key}")
+                    current_offset = f.tell()
+                    if current_offset != model_dict[key].seek_offset:
+                        f.seek(model_dict[key].seek_offset - current_offset, 1)
+                    spec = model_spec[key]
+                    transforms = set(spec.get("transforms", ()))
+                    if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
+                        error = f"Duplicate key {repr(key)}"
+                        print("\n\nERROR: " + error, file=sys.stderr)
+                        raise RuntimeError(error)
+                    tensor = model_dict[key].materialize(f, map_location="cpu")
+                    model_dict[key] = tensor.to("meta")
+
+                    # MTJ requires certain mathematical operations to be performed
+                    # on tensors in order for them to be in the correct format
+                    if "divide_by_shards" in transforms:
+                        tensor /= params["cores_per_replica"]
+                    if "vocab_pad" in transforms:
+                        tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"]))
+                    if "no_transpose" not in transforms:
+                        tensor = tensor.T
+                    tensor.unsqueeze_(0)
+
+                    # Shard the tensor so that parts of the tensor can be used
+                    # on different TPU cores
+                    network.state["params"][spec["module"]][spec["param"]] = move_xmap(
+                        jnp.array(
+                            reshard_reverse(
+                                tensor,
+                                params["cores_per_replica"],
+                                network.state["params"][spec["module"]][spec["param"]].shape,
+                            ),
+                            dtype=jnp.bfloat16,
+                        ),
+                        np.empty(params["cores_per_replica"]),
+                    )
+
+                # Check for tensors that MTJ needs that were not provided in the
+                # HF model
+                for mk, mv in network.state["params"].items():
+                    for pk, pv in mv.items():
+                        if isinstance(pv, PlaceholderTensor):
+                            # The transformers GPT-J models apparently do not
+                            # have embedding bias, whereas MTJ GPT-J models do,
+                            # so we have to supplement an embedding bias tensor
+                            # by creating a tensor with the necessary shape, filled
+                            # with zeros.
+                            if mk == "causal_transformer_shard/~/embedding_shard/~/linear" and pk == "b":
+                                mv[pk] = move_xmap(jnp.zeros(mv[pk].shape, dtype=jnp.bfloat16), np.empty(params["cores_per_replica"]))
+                            else:
+                                error = f"{mk} {pk} could not be found in the model checkpoint"
+                                print("\n\nERROR: " + error, file=sys.stderr)
+                                raise RuntimeError(error)
+            finally:
+                if isinstance(f, zipfile.ZipExtFile):
+                    f.close()
+
+    if os.path.isdir(vars.model.replace('/', '_')):
+        import shutil
+        shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
+    with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
+        if(os.path.isdir(vars.custmodpth)):
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
+            except ValueError as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
+            try:
+                model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
+            except ValueError as e:
+                model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
+        elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
+            try:
+                tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+            except ValueError as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+            try:
+                model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+            except ValueError as e:
+                model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+        else:
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
+            except ValueError as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
+            try:
+                model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache")
+            except ValueError as e:
+                model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")
+
+    network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
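To trace the transform chain above for a single tensor (hypothetical dimensions): a HF linear weight is stored as (out_features, in_features), so unless the map sets "no_transpose" it is transposed, given a leading axis, and handed to reshard_reverse to be split across cores:

    tensor = torch.zeros(4096, 4096)   # (out, in) as stored by transformers
    tensor = tensor.T                  # -> (in, out); skipped when "no_transpose" is set
    tensor.unsqueeze_(0)               # -> (1, 4096, 4096)
    sharded = reshard_reverse(tensor, 8, (8, 512, 4096))   # -> (8, 512, 4096), one slice per core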