Add lazy torch loading support to transformers backend

Gnome Ann 2022-03-04 00:33:10 -05:00
parent 1515996fca
commit 58a2c18821
4 changed files with 196 additions and 41 deletions
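In short: the commit adds a vars.lazy_load flag, a get_lazy_load_callback() factory that builds a per-tensor device map, and per-architecture weight-name maps under maps/; device_config() now takes a model config instead of a loaded model so breakmodel layer assignments can be computed before any weights exist, and moving an already-loaded model is split out into move_model_to_devices(). The sketch below shows how the pieces appear to fit together, reconstructed from the hunks that follow; torch_lazy_loader is the project's own helper module, the keyword arguments are the ones visible in the diff (not a documented public API), and get_lazy_load_callback, n_layers and vars come from the surrounding code.

import torch_lazy_loader
from transformers import AutoModelForCausalLM

# Build the callback that decides, per checkpoint tensor, which device it
# should be materialized on (CPU RAM or one of the breakmodel GPUs).
callback = get_lazy_load_callback(n_layers)

with torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load,
                                           callback=callback,
                                           dematerialized_modules=True):
    # Inside the context manager, checkpoint tensors arrive as LazyTensor
    # placeholders and modules are created without allocating real storage
    # (as the dematerialized_modules flag suggests); the callback then streams
    # each tensor out of the checkpoint zip straight onto its assigned device
    # instead of building a full state dict in CPU RAM first.
    model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")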


@@ -26,6 +26,8 @@ import traceback
import threading
import markdown
import bleach
import itertools
import bisect
from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
@@ -248,6 +250,7 @@ class vars:
newlinemode = "n"
quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page)
debug = False # If set to true, will send debug information to the client for display
lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage
utils.vars = vars
@@ -335,10 +338,10 @@ def device_list(n_layers, primary=None, selected=None):
sep_color = colors.YELLOW
print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}")
def device_config(model):
def device_config(config):
global breakmodel, generator
import breakmodel
n_layers = model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer
n_layers = config.num_layers if hasattr(config, "num_layers") else config.n_layer
if(args.breakmodel_gpulayers is not None):
try:
breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
@@ -411,22 +414,30 @@ def device_config(model):
# If all layers are on the same device, use the old GPU generation mode
while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
breakmodel.gpu_blocks.pop()
if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer)):
if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, config.num_layers if hasattr(config, "num_layers") else config.n_layer)):
vars.breakmodel = False
vars.usegpu = True
vars.gpu_device = len(breakmodel.gpu_blocks)-1
model = model.half().to(vars.gpu_device)
generator = model.generate
return
if(not breakmodel.gpu_blocks):
print("Nothing assigned to a GPU, reverting to CPU only mode")
vars.breakmodel = False
vars.usegpu = False
model = model.to('cpu').float()
return
def move_model_to_devices(model):
global generator
if(not vars.breakmodel):
if(vars.usegpu):
model = model.half().to(vars.gpu_device)
else:
model = model.to('cpu').float()
generator = model.generate
return
model.half().to('cpu')
model.half()
gc.collect()
if(hasattr(model, "transformer")):
model.transformer.wte.to(breakmodel.primary_device)
@@ -1013,6 +1024,67 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
import transformers.generation_utils
from transformers import __version__ as transformers_version
# Lazy loader
import torch_lazy_loader
def get_lazy_load_callback(n_layers):
if not vars.lazy_load:
return
from tqdm import tqdm
if "breakmodel" in globals():
gpu_blocks = breakmodel.gpu_blocks
ram_blocks = n_layers - sum(gpu_blocks)
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
else:
ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
def lazy_load_callback(model_dict, f, **_):
device_map = {}
for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
for layer in range(n_layers):
key = _key.format(layer=layer)
if key not in model_dict:
continue
device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel or layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
device_map[key] = device
for key, value in model_dict.items():
if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
with zipfile.ZipFile(f, "r") as z:
try:
last_storage_key = None
f = None
for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
storage_key = model_dict[key].key
if storage_key != last_storage_key:
last_storage_key = storage_key
if isinstance(f, zipfile.ZipExtFile):
f.close()
f = z.open(f"archive/data/{storage_key}")
current_offset = f.tell()
if current_offset != model_dict[key].seek_offset:
f.seek(model_dict[key].seek_offset - current_offset, 1)
device = device_map[key]
#print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
model_dict[key] = model_dict[key].materialize(f, map_location=torch.device(device))
#print("OK", flush=True)
finally:
if isinstance(f, zipfile.ZipExtFile):
f.close()
return lazy_load_callback
if(vars.lazy_load and "model_config" in globals() and vars.model_type in ("gpt_neo", "gptj", "xglm")):
with open(os.path.join(path.dirname(path.realpath(__file__)), "maps", vars.model_type + ".json")) as f:
lazy_load_spec = json.load(f)
else:
vars.lazy_load = False
# Temporary fix for XGLM positional embedding issues until
# https://github.com/huggingface/transformers/issues/15736
# is resolved
@@ -1247,6 +1319,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# If custom GPT2 model was chosen
if(vars.model == "GPT2Custom"):
vars.lazy_load = False
model_config = open(vars.custmodpth + "/config.json", "r")
js = json.load(model_config)
with(maybe_use_float16()):
@@ -1268,6 +1341,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# feature yet
if(vars.model_type == "gpt2"):
lowmem = {}
# If we're using torch_lazy_loader, we need to get breakmodel config
# early so that it knows where to load the individual model tensors
if(vars.lazy_load and vars.hascuda and vars.breakmodel):
device_config(model_config)
# Download model from Huggingface if it does not exist, otherwise load locally
@@ -1275,43 +1353,43 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
if os.path.isdir(vars.model.replace('/', '_')):
import shutil
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
if(os.path.isdir(vars.custmodpth)):
with(maybe_use_float16()):
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
with(maybe_use_float16()):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
with(maybe_use_float16()):
with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer), dematerialized_modules=True):
if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
lowmem = {}
if(os.path.isdir(vars.custmodpth)):
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)
if not args.colab:
import shutil
model = model.half()
model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
shutil.rmtree("cache/")
if not args.colab:
import shutil
model = model.half()
model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
shutil.rmtree("cache/")
if(vars.hascuda):
if(vars.usegpu):
@@ -1320,7 +1398,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
generator = model.generate
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
vars.modeldim = get_hidden_size_from_model(model)
device_config(model)
if(not vars.lazy_load):
device_config(model.config)
move_model_to_devices(model)
else:
model = model.to('cpu').float()
vars.modeldim = get_hidden_size_from_model(model)
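The device-map arithmetic inside lazy_load_callback is easiest to see with concrete numbers. The sketch below covers only the breakmodel case (CUDA available, vars.breakmodel set); the gpu_blocks values are invented for illustration, whereas the real breakmodel.gpu_blocks comes from --breakmodel_gpulayers or the interactive prompt.

import bisect
import itertools

n_layers = 28
gpu_blocks = [10, 8]                                   # 10 layers on GPU 0, 8 on GPU 1
ram_blocks = n_layers - sum(gpu_blocks)                # remaining 10 layers stay in CPU RAM
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))  # (10, 18)

def layer_device(layer):
    # Mirrors the breakmodel branch of the conditional expression in
    # lazy_load_callback: the first ram_blocks layers stay on the CPU, later
    # layers go to the GPU whose cumulative block range they fall into.
    if layer < ram_blocks:
        return "cpu"
    return bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)

assert layer_device(3) == "cpu"   # one of the first ram_blocks layers
assert layer_device(12) == 0      # offset 2 into the GPU layers -> GPU 0
assert layer_device(27) == 1      # offset 17 is past the 10 GPU-0 layers -> GPU 1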

maps/gpt_neo.json (new file, 25 lines)

@@ -0,0 +1,25 @@
{
"static_weights": {
"transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
"transformer.wpe.weight": {"mtj": {"module": "embedding_shard/~/pos_embs", "param": "w", "axis": 2, "transforms": ["transpose"]}},
"transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
"transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}
},
"layer_weights": {
"transformer.h.{layer}.attn.attention.bias": {},
"transformer.h.{layer}.attn.attention.masked_bias": {},
"transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
"transformer.h.{layer}.attn.attention.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
"transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
"transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
"transformer.h.{layer}.mlp.c_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
"transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}},
"transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}},
"transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}}
}
}
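In this commit the torch loading path only uses the key templates from "layer_weights": lazy_load_callback substitutes each layer index into the "{layer}" placeholder and looks the result up in the checkpoint's state dict, while anything not covered (including the "static_weights" entries above) falls through to the catch-all loop that sends remaining LazyTensors to the GPU or CPU. The "mtj" blocks are not read on this path. A small sketch of the key expansion, with n_layers hard-coded for illustration (it normally comes from config.num_layers or config.n_layer):

import json

with open("maps/gpt_neo.json") as f:
    lazy_load_spec = json.load(f)

n_layers = 2  # illustrative
expanded = [
    template.format(layer=layer)
    for template in lazy_load_spec["layer_weights"]
    for layer in range(n_layers)
]
# e.g. "transformer.h.0.attn.attention.q_proj.weight"; keys that exist in the
# checkpoint get an entry in device_map, missing ones are simply skipped.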

maps/gptj.json (new file, 24 lines)

@@ -0,0 +1,24 @@
{
"static_weights": {
"transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
"transformer.wte.bias": {"mtj": {"module": "embedding_shard/~/linear", "param": "b"}},
"transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
"transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}},
"lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}},
"lm_head.bias": {"mtj": {"module": "projection_shard/~/linear", "param": "b"}}
},
"layer_weights": {
"transformer.h.{layer}.attn.bias": {},
"transformer.h.{layer}.attn.masked_bias": {},
"transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
"transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
"transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
"transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
"transformer.h.{layer}.mlp.fc_out.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
"transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}
}
}

maps/xglm.json (new file, 26 lines)

@@ -0,0 +1,26 @@
{
"static_weights": {
"model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
"model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
"model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}},
"lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}}
},
"layer_weights": {
"model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
"model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b", "axis": 1}},
"model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
"model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b", "axis": 1}},
"model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
"model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b", "axis": 1}},
"model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
"model.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
"model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
"model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
"model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
"model.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
"model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}},
"model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}},
"model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}}
}
}
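The three map files above are named after the transformers model_type, which is how the loader picks one earlier in the diff; when no map exists for the architecture, lazy loading is switched off. A simplified restatement of that selection logic (the helper name is made up for this sketch, and the real code resolves the maps/ directory relative to the script's own path):

import json
import os

def load_lazy_load_spec(model_type, maps_dir="maps"):
    # gpt_neo.json, gptj.json and xglm.json are the only maps shipped in this commit.
    if model_type not in ("gpt_neo", "gptj", "xglm"):
        return None  # caller then falls back to vars.lazy_load = False
    with open(os.path.join(maps_dir, model_type + ".json")) as f:
        return json.load(f)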