diff --git a/aiserver.py b/aiserver.py index 66c0e197..2a22d6b1 100644 --- a/aiserver.py +++ b/aiserver.py @@ -26,6 +26,8 @@ import traceback import threading import markdown import bleach +import itertools +import bisect from collections.abc import Iterable from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List @@ -248,6 +250,7 @@ class vars: newlinemode = "n" quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page) debug = False # If set to true, will send debug information to the client for display + lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage utils.vars = vars @@ -335,10 +338,10 @@ def device_list(n_layers, primary=None, selected=None): sep_color = colors.YELLOW print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}") -def device_config(model): +def device_config(config): global breakmodel, generator import breakmodel - n_layers = model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer + n_layers = config.num_layers if hasattr(config, "num_layers") else config.n_layer if(args.breakmodel_gpulayers is not None): try: breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(','))) @@ -411,22 +414,30 @@ def device_config(model): # If all layers are on the same device, use the old GPU generation mode while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0): breakmodel.gpu_blocks.pop() - if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer)): + if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, config.num_layers if hasattr(config, "num_layers") else config.n_layer)): vars.breakmodel = False vars.usegpu = True vars.gpu_device = len(breakmodel.gpu_blocks)-1 - model = model.half().to(vars.gpu_device) - generator = model.generate return if(not breakmodel.gpu_blocks): print("Nothing assigned to a GPU, reverting to CPU only mode") vars.breakmodel = False vars.usegpu = False - model = model.to('cpu').float() + return + +def move_model_to_devices(model): + global generator + + if(not vars.breakmodel): + if(vars.usegpu): + model = model.half().to(vars.gpu_device) + else: + model = model.to('cpu').float() generator = model.generate return - model.half().to('cpu') + + model.half() gc.collect() if(hasattr(model, "transformer")): model.transformer.wte.to(breakmodel.primary_device) @@ -1013,6 +1024,67 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme import transformers.generation_utils from transformers import __version__ as transformers_version + # Lazy loader + import torch_lazy_loader + def get_lazy_load_callback(n_layers): + if not vars.lazy_load: + return + + from tqdm import tqdm + + if "breakmodel" in globals(): + gpu_blocks = breakmodel.gpu_blocks + ram_blocks = ram_blocks = n_layers - sum(gpu_blocks) + cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks)) + else: + ram_blocks = gpu_blocks = cumulative_gpu_blocks = None + + def lazy_load_callback(model_dict, f, **_): + device_map = {} + + for _key, spec in lazy_load_spec.get("layer_weights", {}).items(): + for layer in range(n_layers): + key = _key.format(layer=layer) + if key not in model_dict: + continue + device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not 
vars.breakmodel or layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks) + device_map[key] = device + + for key, value in model_dict.items(): + if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map: + device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" + + with zipfile.ZipFile(f, "r") as z: + try: + last_storage_key = None + f = None + for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"): + storage_key = model_dict[key].key + if storage_key != last_storage_key: + last_storage_key = storage_key + if isinstance(f, zipfile.ZipExtFile): + f.close() + f = z.open(f"archive/data/{storage_key}") + current_offset = f.tell() + if current_offset != model_dict[key].seek_offset: + f.seek(model_dict[key].seek_offset - current_offset, 1) + device = device_map[key] + #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True) + model_dict[key] = model_dict[key].materialize(f, map_location=torch.device(device)) + #print("OK", flush=True) + finally: + if isinstance(f, zipfile.ZipExtFile): + f.close() + + return lazy_load_callback + + if(vars.lazy_load and "model_config" in globals() and vars.model_type in ("gpt_neo", "gptj", "xglm")): + with open(os.path.join(path.dirname(path.realpath(__file__)), "maps", vars.model_type + ".json")) as f: + lazy_load_spec = json.load(f) + + else: + vars.lazy_load = False + # Temporary fix for XGLM positional embedding issues until # https://github.com/huggingface/transformers/issues/15736 # is resolved @@ -1247,6 +1319,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # If custom GPT2 model was chosen if(vars.model == "GPT2Custom"): + vars.lazy_load = False model_config = open(vars.custmodpth + "/config.json", "r") js = json.load(model_config) with(maybe_use_float16()): @@ -1268,6 +1341,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # feature yet if(vars.model_type == "gpt2"): lowmem = {} + + # If we're using torch_lazy_loader, we need to get breakmodel config + # early so that it knows where to load the individual model tensors + if(vars.lazy_load and vars.hascuda and vars.breakmodel): + device_config(model_config) # Download model from Huggingface if it does not exist, otherwise load locally @@ -1275,43 +1353,43 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme if os.path.isdir(vars.model.replace('/', '_')): import shutil shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_'))) - if(os.path.isdir(vars.custmodpth)): - with(maybe_use_float16()): - try: - tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache") - try: - model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) - except ValueError as e: - model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) - elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))): - with(maybe_use_float16()): - try: - tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") 
- try: - model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) - except ValueError as e: - model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) - else: - try: - tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache") - with(maybe_use_float16()): + with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer), dematerialized_modules=True): + if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time + lowmem = {} + if(os.path.isdir(vars.custmodpth)): + try: + tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) + elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))): + try: + tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) + else: + try: + tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache") try: model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem) except ValueError as e: model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem) - - if not args.colab: - import shutil - model = model.half() - model.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) - tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) - shutil.rmtree("cache/") + + if not args.colab: + import shutil + model = model.half() + model.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) + tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) + shutil.rmtree("cache/") if(vars.hascuda): if(vars.usegpu): @@ -1320,7 +1398,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme generator = model.generate elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel) vars.modeldim = get_hidden_size_from_model(model) - device_config(model) + if(not vars.lazy_load): + device_config(model.config) + move_model_to_devices(model) else: model = model.to('cpu').float() vars.modeldim = get_hidden_size_from_model(model) diff --git a/maps/gpt_neo.json b/maps/gpt_neo.json new file mode 100644 index 00000000..a93fb26e --- /dev/null +++ b/maps/gpt_neo.json @@ -0,0 +1,25 @@ +{ + "static_weights": { + "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, 
"transforms": ["transpose", "vocab_pad"]}}, + "transformer.wpe.weight": {"mtj": {"module": "embedding_shard/~/pos_embs", "param": "w", "axis": 2, "transforms": ["transpose"]}}, + "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}}, + "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}} + }, + "layer_weights": { + "transformer.h.{layer}.attn.attention.bias": {}, + "transformer.h.{layer}.attn.attention.masked_bias": {}, + "transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}}, + "transformer.h.{layer}.attn.attention.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}}, + "transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}}, + "transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}}, + "transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}}, + "transformer.h.{layer}.mlp.c_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, + "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, + "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}, + "transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}}, + "transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}} + } +} diff --git a/maps/gptj.json b/maps/gptj.json new file mode 100644 index 00000000..51a788d7 --- /dev/null +++ b/maps/gptj.json @@ -0,0 +1,24 @@ +{ + "static_weights": { + "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}}, + "transformer.wte.bias": {"mtj": {"module": "embedding_shard/~/linear", "param": "b"}}, + "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}}, + "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}, + "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}}, + "lm_head.bias": {"mtj": {"module": "projection_shard/~/linear", "param": "b"}} + }, + "layer_weights": { + "transformer.h.{layer}.attn.bias": {}, + "transformer.h.{layer}.attn.masked_bias": {}, + "transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}}, + 
"transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}}, + "transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}}, + "transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}}, + "transformer.h.{layer}.mlp.fc_out.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, + "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, + "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}} + } +} diff --git a/maps/xglm.json b/maps/xglm.json new file mode 100644 index 00000000..beb90985 --- /dev/null +++ b/maps/xglm.json @@ -0,0 +1,26 @@ +{ + "static_weights": { + "model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}}, + "model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}}, + "model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}, + "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}} + }, + "layer_weights": { + "model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}}, + "model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b", "axis": 1}}, + "model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}}, + "model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b", "axis": 1}}, + "model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}}, + "model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b", "axis": 1}}, + "model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}}, + "model.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}}, + "model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}}, + "model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}}, + "model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}}, + "model.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, + "model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, + "model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}, + "model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}}, + "model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}} + } +}