Add lazy torch loading support to transformers backend

Author: Gnome Ann
Date: 2022-03-04 00:33:10 -05:00
Parent: 1515996fca
Commit: 58a2c18821
4 changed files with 196 additions and 41 deletions

aiserver.py

@@ -26,6 +26,8 @@ import traceback
import threading
import markdown
import bleach
import itertools
import bisect
from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
@@ -248,6 +250,7 @@ class vars:
    newlinemode = "n"
    quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page)
    debug = False # If set to true, will send debug information to the client for display
    lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage

utils.vars = vars
@@ -335,10 +338,10 @@ def device_list(n_layers, primary=None, selected=None):
    sep_color = colors.YELLOW
    print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}")

def device_config(config):
    global breakmodel, generator
    import breakmodel
    n_layers = config.num_layers if hasattr(config, "num_layers") else config.n_layer
    if(args.breakmodel_gpulayers is not None):
        try:
            breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
@@ -411,22 +414,30 @@ def device_config(model):
    # If all layers are on the same device, use the old GPU generation mode
    while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
        breakmodel.gpu_blocks.pop()
    if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, config.num_layers if hasattr(config, "num_layers") else config.n_layer)):
        vars.breakmodel = False
        vars.usegpu = True
        vars.gpu_device = len(breakmodel.gpu_blocks)-1
        return
    if(not breakmodel.gpu_blocks):
        print("Nothing assigned to a GPU, reverting to CPU only mode")
        vars.breakmodel = False
        vars.usegpu = False
        return

def move_model_to_devices(model):
    global generator
    if(not vars.breakmodel):
        if(vars.usegpu):
            model = model.half().to(vars.gpu_device)
        else:
            model = model.to('cpu').float()
        generator = model.generate
        return
    model.half()
    gc.collect()
    if(hasattr(model, "transformer")):
        model.transformer.wte.to(breakmodel.primary_device)
@@ -1013,6 +1024,67 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
import transformers.generation_utils
from transformers import __version__ as transformers_version
# Lazy loader
import torch_lazy_loader
def get_lazy_load_callback(n_layers):
    if not vars.lazy_load:
        return
    from tqdm import tqdm
    if "breakmodel" in globals():
        gpu_blocks = breakmodel.gpu_blocks
        ram_blocks = n_layers - sum(gpu_blocks)
        cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
    else:
        ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
    def lazy_load_callback(model_dict, f, **_):
        device_map = {}
        for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
            for layer in range(n_layers):
                key = _key.format(layer=layer)
                if key not in model_dict:
                    continue
                device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel or layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
                device_map[key] = device
        for key, value in model_dict.items():
            if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
                device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
        with zipfile.ZipFile(f, "r") as z:
            try:
                last_storage_key = None
                f = None
                for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
                    storage_key = model_dict[key].key
                    if storage_key != last_storage_key:
                        last_storage_key = storage_key
                        if isinstance(f, zipfile.ZipExtFile):
                            f.close()
                        f = z.open(f"archive/data/{storage_key}")
                    current_offset = f.tell()
                    if current_offset != model_dict[key].seek_offset:
                        f.seek(model_dict[key].seek_offset - current_offset, 1)
                    device = device_map[key]
                    #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
                    model_dict[key] = model_dict[key].materialize(f, map_location=torch.device(device))
                    #print("OK", flush=True)
            finally:
                if isinstance(f, zipfile.ZipExtFile):
                    f.close()
    return lazy_load_callback

if(vars.lazy_load and "model_config" in globals() and vars.model_type in ("gpt_neo", "gptj", "xglm")):
    with open(os.path.join(path.dirname(path.realpath(__file__)), "maps", vars.model_type + ".json")) as f:
        lazy_load_spec = json.load(f)
else:
    vars.lazy_load = False
# Temporary fix for XGLM positional embedding issues until
# https://github.com/huggingface/transformers/issues/15736
# is resolved
@@ -1247,6 +1319,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# If custom GPT2 model was chosen
if(vars.model == "GPT2Custom"):
    vars.lazy_load = False
    model_config = open(vars.custmodpth + "/config.json", "r")
    js = json.load(model_config)
    with(maybe_use_float16()):
@@ -1269,49 +1342,54 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
if(vars.model_type == "gpt2"):
    lowmem = {}

# If we're using torch_lazy_loader, we need to get breakmodel config
# early so that it knows where to load the individual model tensors
if(vars.lazy_load and vars.hascuda and vars.breakmodel):
    device_config(model_config)

# Download model from Huggingface if it does not exist, otherwise load locally
#If we specify a model and it's in the root directory, we need to move it to the models directory (legacy folder structure to new)
if os.path.isdir(vars.model.replace('/', '_')):
    import shutil
    shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer), dematerialized_modules=True):
    if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
        lowmem = {}
    if(os.path.isdir(vars.custmodpth)):
        try:
            tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
        except ValueError as e:
            tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
        try:
            model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
        except ValueError as e:
            model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
    elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
        try:
            tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
        except ValueError as e:
            tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
        try:
            model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
        except ValueError as e:
            model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
        except ValueError as e:
            tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
        try:
            model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)
        except ValueError as e:
            model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)

        if not args.colab:
            import shutil
            model = model.half()
            model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
            tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
            shutil.rmtree("cache/")
if(vars.hascuda):
    if(vars.usegpu):
@@ -1320,7 +1398,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
        generator = model.generate
    elif(vars.breakmodel):  # Use both RAM and VRAM (breakmodel)
        vars.modeldim = get_hidden_size_from_model(model)
        if(not vars.lazy_load):
            device_config(model.config)
        move_model_to_devices(model)
    else:
        model = model.to('cpu').float()
        vars.modeldim = get_hidden_size_from_model(model)
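For reference, the layer-to-device rule that lazy_load_callback applies above can be exercised on its own. The following is a minimal sketch; the block sizes (20 layers kept in CPU RAM, 5 on GPU 0, 3 on GPU 1) are made up for illustration and are not part of the commit.

import bisect
import itertools

n_layers = 28
gpu_blocks = [5, 3]                      # stand-in for breakmodel.gpu_blocks
ram_blocks = n_layers - sum(gpu_blocks)  # layers whose tensors stay in CPU RAM
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))  # (5, 8)

def layer_device(layer):
    # Same rule as in lazy_load_callback: early layers stay on the CPU, the
    # rest go to the first GPU whose cumulative block count exceeds
    # (layer - ram_blocks).
    if layer < ram_blocks:
        return "cpu"
    return bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)

print([layer_device(i) for i in range(n_layers)])
# ['cpu', ..., 'cpu', 0, 0, 0, 0, 0, 1, 1, 1]  (20 CPU layers, then GPUs 0 and 1)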

maps/gpt_neo.json (new file, 25 lines)

@@ -0,0 +1,25 @@
{
"static_weights": {
"transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
"transformer.wpe.weight": {"mtj": {"module": "embedding_shard/~/pos_embs", "param": "w", "axis": 2, "transforms": ["transpose"]}},
"transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
"transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}
},
"layer_weights": {
"transformer.h.{layer}.attn.attention.bias": {},
"transformer.h.{layer}.attn.attention.masked_bias": {},
"transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
"transformer.h.{layer}.attn.attention.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
"transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
"transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
"transformer.h.{layer}.mlp.c_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
"transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}},
"transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}},
"transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}}
}
}
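As a usage note, the "{layer}" placeholders in "layer_weights" are expanded with str.format and checked against the checkpoint's state dict before a device is assigned, exactly as lazy_load_callback does above. A small sketch, assuming the file is read from maps/gpt_neo.json and using a toy two-layer key set:

import json

with open("maps/gpt_neo.json") as f:
    spec = json.load(f)

n_layers = 2  # illustrative; the real value comes from the model config
checkpoint_keys = {
    "transformer.wte.weight",
    "transformer.h.0.attn.attention.q_proj.weight",
    "transformer.h.1.attn.attention.q_proj.weight",
}

matched = []
for template in spec["layer_weights"]:
    for layer in range(n_layers):
        key = template.format(layer=layer)
        if key in checkpoint_keys:  # same guard as in lazy_load_callback
            matched.append((layer, key))

print(matched)
# [(0, 'transformer.h.0.attn.attention.q_proj.weight'),
#  (1, 'transformer.h.1.attn.attention.q_proj.weight')]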

maps/gptj.json (new file, 24 lines)

@@ -0,0 +1,24 @@
{
"static_weights": {
"transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
"transformer.wte.bias": {"mtj": {"module": "embedding_shard/~/linear", "param": "b"}},
"transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
"transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}},
"lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}},
"lm_head.bias": {"mtj": {"module": "projection_shard/~/linear", "param": "b"}}
},
"layer_weights": {
"transformer.h.{layer}.attn.bias": {},
"transformer.h.{layer}.attn.masked_bias": {},
"transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
"transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
"transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
"transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
"transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
"transformer.h.{layer}.mlp.fc_out.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
"transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}
}
}

maps/xglm.json (new file, 26 lines)

@@ -0,0 +1,26 @@
{
"static_weights": {
"model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}},
"model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}},
"model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}},
"lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}}
},
"layer_weights": {
"model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}},
"model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b", "axis": 1}},
"model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}},
"model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b", "axis": 1}},
"model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}},
"model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b", "axis": 1}},
"model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}},
"model.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
"model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}},
"model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}},
"model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}},
"model.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}},
"model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}},
"model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}},
"model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}}
}
}
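Taken together, the pieces above compose into the loading flow below. This is a condensed sketch rather than code from the commit: it assumes it runs inside aiserver.py, where vars, device_config, maybe_use_float16, get_lazy_load_callback and move_model_to_devices are defined, and the AutoConfig call and model name are purely illustrative.

import torch_lazy_loader
from transformers import AutoConfig, AutoModelForCausalLM

model_config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B", cache_dir="cache")
n_layers = model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer

# With breakmodel, the GPU/CPU layer split must be known before any tensor is
# materialized, hence the early device_config(model_config) call on the config.
if(vars.lazy_load and vars.hascuda and vars.breakmodel):
    device_config(model_config)

with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(
    enable=vars.lazy_load,
    callback=get_lazy_load_callback(n_layers),
    dematerialized_modules=True,
):
    # Checkpoint tensors arrive as LazyTensors; the callback materializes each
    # one directly onto its assigned device instead of building the whole model
    # in CPU RAM first.
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", cache_dir="cache")

if(vars.hascuda and vars.breakmodel):
    move_model_to_devices(model)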