From f5e689a725eb3ad450e80aa4a05d9890b84b5630 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Thu, 12 May 2022 19:09:31 -0400 Subject: [PATCH 1/8] Upload maps/opt.json and update requirements --- maps/opt.json | 32 ++++++++++++++++++++++++++++++++ requirements.txt | 2 +- requirements_mtj.txt | 4 ++-- tpu_mtj_backend.py | 2 ++ 4 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 maps/opt.json diff --git a/maps/opt.json b/maps/opt.json new file mode 100644 index 00000000..b9667ac9 --- /dev/null +++ b/maps/opt.json @@ -0,0 +1,32 @@ +{ + "mtj_compat": "opt", + "mtj_pe": "fixed", + "mtj_config_map": { + "do_layer_norm_before": ["do_layer_norm_before", true], + "d_model": "hidden_size", + "n_heads": "num_attention_heads", + "layers": "num_hidden_layers" + }, + "static_weights": { + "decoder.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, + "decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}} + }, + "layer_weights": { + "decoder.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}}, + "decoder.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b"}}, + "decoder.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}}, + "decoder.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b"}}, + "decoder.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}}, + "decoder.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b"}}, + "decoder.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}}, + "decoder.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}}, + "decoder.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}}, + "decoder.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}}, + "decoder.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}}, + "decoder.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, + "decoder.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}}, + "decoder.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}}, + "decoder.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}}, + "decoder.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}} + } +} diff --git a/requirements.txt b/requirements.txt index 897f9e8e..7b5b967c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -transformers>=4.17 +transformers>=4.19 Flask Flask-SocketIO requests diff --git a/requirements_mtj.txt b/requirements_mtj.txt index 416a06a4..e2a6c4e1 100644 --- a/requirements_mtj.txt +++ b/requirements_mtj.txt @@ -5,9 +5,9 @@ requests optax >= 0.0.5, <= 0.0.9 dm-haiku == 0.0.5 jax == 0.2.21 -transformers >= 4.17 +transformers >= 4.19 progressbar2 -git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck +git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck-staging flask Flask-SocketIO flask-cloudflared >= 0.0.5 diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 1f67763f..75b4ee9c 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1200,6 +1200,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo # MTJ requires certain mathematical operations to be performed # on tensors in order for them to be in the correct format + if "remove_first_two_rows" in transforms: + tensor = tensor[2:] if "divide_by_shards" in transforms: tensor /= params["cores_per_replica"] if "vocab_pad" in transforms: From 4fa5f1cd6afb3486704870c2e56e84f5888e7f71 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Thu, 12 May 2022 22:21:15 -0400 Subject: [PATCH 2/8] Add TPU support for OPT-350M The 350M model seems to have a different structure than the other ones ??? --- aiserver.py | 8 ++++---- maps/opt.json | 5 ++++- tpu_mtj_backend.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/aiserver.py b/aiserver.py index 0dc19c5a..1f105701 100644 --- a/aiserver.py +++ b/aiserver.py @@ -772,7 +772,7 @@ def spRequest(filename): tensor = tensor.reshape( tpu_mtj_backend.params["cores_per_replica"], -1, - tpu_mtj_backend.params["d_model"], + tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]), ) vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor)) else: @@ -1574,14 +1574,14 @@ else: global np if 'np' not in globals(): import numpy as np - tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32) + tensor = np.zeros((1, tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])), dtype=np.float32) rows = tensor.shape[0] padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows tensor = np.pad(tensor, ((0, padding_amount), (0, 0))) tensor = tensor.reshape( tpu_mtj_backend.params["cores_per_replica"], -1, - tpu_mtj_backend.params["d_model"], + tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]), ) vars.sp = tpu_mtj_backend.shard_xmap(tensor) soft_tokens = np.arange( @@ -1682,7 +1682,7 @@ else: loadmodelsettings() loadsettings() tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig) - vars.modeldim = int(tpu_mtj_backend.params["d_model"]) + vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])) tokenizer = tpu_mtj_backend.tokenizer else: loadsettings() diff --git a/maps/opt.json b/maps/opt.json index b9667ac9..c99ae19f 100644 --- a/maps/opt.json +++ b/maps/opt.json @@ -3,13 +3,16 @@ "mtj_pe": "fixed", "mtj_config_map": { "do_layer_norm_before": ["do_layer_norm_before", true], + "d_embed": "word_embed_proj_dim", "d_model": "hidden_size", "n_heads": "num_attention_heads", "layers": "num_hidden_layers" }, "static_weights": { "decoder.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, - "decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}} + "decoder.project_in.weight": {"mtj": {"module": "embedding_shard", "param": "project_in"}}, + "decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}}, + "decoder.project_out.weight": {"mtj": {"module": "projection_shard", "param": "project_out"}} }, "layer_weights": { "decoder.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}}, diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 75b4ee9c..b956648b 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1054,7 +1054,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo # by the number of TPU cores, and fall back to one core if an even # number of TPU cores is not possible. for c in (8, 6, 4, 2, 1): - if 0 == params["n_heads"] % c == params["d_model"] % c: + if 0 == params["n_heads"] % c == params.get("d_embed", params["d_model"]) % c: params["cores_per_replica"] = c break From b1d8797a54d4e22b6d43062930ed63765586bdc5 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Thu, 12 May 2022 23:51:40 -0400 Subject: [PATCH 3/8] Allow TPU Colab to load sharded HF models --- aiserver.py | 20 +++++++++++++++++++- requirements_mtj.txt | 2 +- tpu_mtj_backend.py | 12 ++++++++++-- utils.py | 14 +++++++++++++- 4 files changed, 43 insertions(+), 5 deletions(-) diff --git a/aiserver.py b/aiserver.py index 1f105701..6c9401e2 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1127,13 +1127,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go from transformers import __version__ as transformers_version from transformers import PreTrainedModel + from transformers import modeling_utils old_from_pretrained = PreTrainedModel.from_pretrained.__func__ @classmethod def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + utils.num_shards = None + utils.current_shard = 0 if not args.no_aria2: utils.aria2_hook(pretrained_model_name_or_path, **kwargs) return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs) PreTrainedModel.from_pretrained = new_from_pretrained + old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files + def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs): + utils.num_shards = utils.get_num_shards(index_filename) + return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs) + modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files # Lazy loader import torch_lazy_loader @@ -1170,7 +1178,9 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go last_storage_key = None f = None current_offset = 0 - for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"): + if utils.num_shards is not None: + utils.current_shard += 1 + for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")): storage_key = model_dict[key].key if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset: last_storage_key = storage_key @@ -1560,13 +1570,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache") else: from transformers import PreTrainedModel + from transformers import modeling_utils old_from_pretrained = PreTrainedModel.from_pretrained.__func__ @classmethod def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + utils.num_shards = None + utils.current_shard = 0 if not args.no_aria2: utils.aria2_hook(pretrained_model_name_or_path, **kwargs) return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs) PreTrainedModel.from_pretrained = new_from_pretrained + old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files + def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs): + utils.num_shards = utils.get_num_shards(index_filename) + return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs) + modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files def tpumtjgetsofttokens(): soft_tokens = None diff --git a/requirements_mtj.txt b/requirements_mtj.txt index e2a6c4e1..0f723a49 100644 --- a/requirements_mtj.txt +++ b/requirements_mtj.txt @@ -7,7 +7,7 @@ dm-haiku == 0.0.5 jax == 0.2.21 transformers >= 4.19 progressbar2 -git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck-staging +git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck flask Flask-SocketIO flask-cloudflared >= 0.0.5 diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index b956648b..2fa149d7 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -27,6 +27,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ''' +import utils + import multiprocessing from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar import progressbar @@ -1163,8 +1165,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo last_storage_key = None f = None current_offset = 0 - print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n") - for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"): + if utils.current_shard == 0: + print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n") + if utils.num_shards is not None: + utils.current_shard += 1 + for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")): # Some model weights are used by transformers but not by MTJ. # We have to materialize these weights anyways because @@ -1225,6 +1230,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo np.empty(params["cores_per_replica"]), ) + if utils.num_shards is not None and utils.current_shard < utils.num_shards: + return + # Check for tensors that MTJ needs that were not provided in the # HF model for mk, mv in network.state["params"].items(): diff --git a/utils.py b/utils.py index c6eb85ec..0fdfa125 100644 --- a/utils.py +++ b/utils.py @@ -6,8 +6,11 @@ import subprocess import tempfile import requests import os +from typing import Optional vars = None +num_shards: Optional[int] = None +current_shard = 0 #==================================================================# # Decorator to prevent a function's actions from being run until @@ -133,7 +136,7 @@ def decodenewlines(txt): return txt #==================================================================# -# Downloads sharded huggingface checkpoints using aria2c if possible +# Downloads huggingface checkpoints using aria2c if possible #==================================================================# def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs): import transformers @@ -225,3 +228,12 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n)) with open(os.path.join(_cache_dir, n + ".json"), "w") as f: json.dump({"url": u, "etag": t}, f) + +#==================================================================# +# Given the path to a pytorch_model.bin.index.json, returns how many +# shards there are in the model +#==================================================================# +def get_num_shards(filename): + with open(filename) as f: + map_data = json.load(f) + return len(set(map_data["weight_map"].values())) From defbb53b689fa448b4848eb257046934145c1d9b Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 13 May 2022 01:03:38 -0400 Subject: [PATCH 4/8] OPT breakmodel --- aiserver.py | 49 ++++++++++---- breakmodel.py | 182 ++++++++++++++++++++++++++++++++++++++++++++++++-- utils.py | 6 ++ 3 files changed, 220 insertions(+), 17 deletions(-) diff --git a/aiserver.py b/aiserver.py index 6c9401e2..f988e6ba 100644 --- a/aiserver.py +++ b/aiserver.py @@ -274,7 +274,7 @@ class vars: recentrngm = None # If a new random game was recently generated without Submitting after, this is the memory used (as a string), otherwise this is None useprompt = False # Whether to send the full prompt with every submit action breakmodel = False # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only - bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J/XGLM only, currently) + bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J/XGLM/OPT only, currently) nobreakmodel = False # Something specifically requested Breakmodel to be disabled (For example a models config) smandelete = False # Whether stories can be deleted from inside the browser smanrename = False # Whether stories can be renamed from inside the browser @@ -391,7 +391,7 @@ def device_list(n_layers, primary=None, selected=None): def device_config(config): global breakmodel, generator import breakmodel - n_layers = config.num_layers if hasattr(config, "num_layers") else config.n_layer + n_layers = utils.num_layers(config) if(args.breakmodel_gpulayers is not None): try: breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(','))) @@ -464,7 +464,7 @@ def device_config(config): # If all layers are on the same device, use the old GPU generation mode while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0): breakmodel.gpu_blocks.pop() - if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, config.num_layers if hasattr(config, "num_layers") else config.n_layer)): + if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, utils.num_layers(config))): vars.breakmodel = False vars.usegpu = True vars.gpu_device = len(breakmodel.gpu_blocks)-1 @@ -496,22 +496,33 @@ def move_model_to_devices(model): model.lm_head.to(breakmodel.primary_device) if(hasattr(model.transformer, 'wpe')): model.transformer.wpe.to(breakmodel.primary_device) - else: + elif(not hasattr(model.model, "decoder")): model.model.embed_tokens.to(breakmodel.primary_device) model.model.layer_norm.to(breakmodel.primary_device) model.lm_head.to(breakmodel.primary_device) model.model.embed_positions.to(breakmodel.primary_device) + else: + model.model.decoder.embed_tokens.to(breakmodel.primary_device) + if(model.model.decoder.project_in is not None): + model.model.decoder.project_in.to(breakmodel.primary_device) + if(model.model.decoder.project_out is not None): + model.model.decoder.project_out.to(breakmodel.primary_device) + model.model.decoder.embed_positions.to(breakmodel.primary_device) gc.collect() GPTNeoModel.forward = breakmodel.new_forward_neo if("GPTJModel" in globals()): GPTJModel.forward = breakmodel.new_forward_neo # type: ignore if("XGLMModel" in globals()): XGLMModel.forward = breakmodel.new_forward_xglm # type: ignore + if("OPTDecoder" in globals()): + OPTDecoder.forward = breakmodel.new_forward_opt # type: ignore generator = model.generate if(hasattr(model, "transformer")): breakmodel.move_hidden_layers(model.transformer) - else: + elif(not hasattr(model.model, "decoder")): breakmodel.move_hidden_layers(model.model, model.model.layers) + else: + breakmodel.move_hidden_layers(model.model.decoder, model.model.decoder.layers) #==================================================================# # Allow the models to override some settings @@ -911,7 +922,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go loadsettings() print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="") vars.hascuda = torch.cuda.is_available() - vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel + vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm", "opt") and not vars.nobreakmodel if(args.breakmodel is not None and args.breakmodel): print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).", file=sys.stderr) if(args.breakmodel_layers is not None): @@ -1123,6 +1134,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go globals()[m] = getattr(__import__("transformers"), m) except: pass + try: + from transformers.models.opt.modeling_opt import OPTDecoder + except: + pass import transformers.generation_utils from transformers import __version__ as transformers_version @@ -1253,8 +1268,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go input_ids.clamp_(max=self.config.vocab_size-1) if(hasattr(self, "transformer")): inputs_embeds = self.transformer.wte(input_ids) - else: + elif(not hasattr(model.model, "decoder")): inputs_embeds = self.model.embed_tokens(input_ids) + else: + inputs_embeds = self.model.decoder.embed_tokens(input_ids) if(vars.sp is not None): vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device) inputs_embeds = torch.where( @@ -1262,14 +1279,14 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go vars.sp[shifted_input_ids.clamp(min=0)], inputs_embeds, ) - if(not hasattr(self, "transformer")): + if(hasattr(self.model, "embed_scale")): inputs_embeds *= self.model.embed_scale kwargs['inputs_embeds'] = inputs_embeds return old_forward(self, *args, **kwargs) cls.forward = new_causallm_forward for cls in (GPT2LMHeadModel, GPTNeoForCausalLM): patch_causallm(cls) - for c in ("GPTJForCausalLM", "XGLMForCausalLM"): + for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"): try: patch_causallm(getattr(__import__("transformers"), c)) except: @@ -1430,12 +1447,18 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go def get_hidden_size_from_model(model): try: - return int(model.transformer.hidden_size) + return int(model.model.decoder.project_in.in_features) except: try: - return int(model.transformer.embed_dim) + return int(model.model.decoder.embed_tokens.out_features) except: - return int(model.lm_head.in_features) + try: + return int(model.transformer.hidden_size) + except: + try: + return int(model.transformer.embed_dim) + except: + return int(model.lm_head.in_features) def maybe_low_cpu_mem_usage() -> Dict[str, Any]: if(packaging.version.parse(transformers_version) < packaging.version.parse("4.11.0")): @@ -1490,7 +1513,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go import shutil shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_'))) print("\n", flush=True) - with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer) if vars.lazy_load else None, dematerialized_modules=True): + with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if vars.lazy_load else None, dematerialized_modules=True): if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time lowmem = {} if(os.path.isdir(vars.custmodpth)): diff --git a/breakmodel.py b/breakmodel.py index 9818e6d9..262d39bd 100644 --- a/breakmodel.py +++ b/breakmodel.py @@ -633,11 +633,11 @@ def new_forward_xglm( layer_outputs = decoder_layer( hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states, attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask, - encoder_hidden_states=encoder_hidden_states.to(device) if encoder_hidden_states is not None else None, - encoder_attention_mask=encoder_attention_mask.to(device) if encoder_attention_mask is not None else None, - layer_head_mask=((head_mask[idx].to(device) if head_mask[idx] is not None else None) if head_mask is not None else None), + encoder_hidden_states=encoder_hidden_states.to(device) if breakmodel and encoder_hidden_states is not None else encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask.to(device) if breakmodel and encoder_attention_mask is not None else encoder_attention_mask, + layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None), cross_attn_layer_head_mask=( - (cross_attn_head_mask[idx].to(device) if cross_attn_head_mask[idx] is not None else None) if cross_attn_head_mask is not None else None + (cross_attn_head_mask[idx].to(device) if breakmodel and cross_attn_head_mask[idx] is not None else cross_attn_head_mask[idx]) if cross_attn_head_mask is not None else None ), past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value, output_attentions=output_attentions, @@ -686,3 +686,177 @@ def new_forward_xglm( attentions=all_self_attns, cross_attentions=all_cross_attentions, ) + + +def new_forward_opt( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, +): + assert len(gpu_blocks) <= torch.cuda.device_count() + assert sum(gpu_blocks) <= len(self.layers) + ram_blocks = len(self.layers) - sum(gpu_blocks) + cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks)) + + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + if breakmodel: + input_ids = input_ids.to(primary_device) + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if breakmodel: + inputs_embeds = inputs_embeds.to(primary_device) + if attention_mask is None: + attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device) + + positions = self.embed_positions(attention_mask)[:, past_key_values_length:, :] + if breakmodel: + positions = positions.to(primary_device) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + positions + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + if breakmodel and ram_blocks: + copystream = torch.cuda.Stream(device=primary_device, priority=-1) + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + i = idx + if breakmodel: + if i in range(ram_blocks): + index1 = (i+1)%ram_blocks + for param1,param2 in zip(self.layers[index1].parameters(),self.layers[(i-1)%ram_blocks].parameters()): + param1.data = param2.data + for param1,param2 in zip(self.layers[index1].parameters(),self.extrastorage[index1].parameters()): + with torch.cuda.stream(copystream): + torch.cuda.comm.broadcast(param2.data,out = [param1.data]) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + if breakmodel: + device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask, + layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None), + past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if breakmodel: + if i in range(ram_blocks): + torch.cuda.synchronize() + torch.cuda.empty_cache() + + if breakmodel: + if ram_blocks: + del copystream + torch.cuda.empty_cache() + hidden_states = hidden_states.to(primary_device) + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + if breakmodel: + hidden_states = hidden_states.to(primary_device) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) diff --git a/utils.py b/utils.py index 0fdfa125..ea3cab94 100644 --- a/utils.py +++ b/utils.py @@ -135,6 +135,12 @@ def decodenewlines(txt): return txt.replace("", '\n') return txt +#==================================================================# +# Returns number of layers given an HF model config +#==================================================================# +def num_layers(config): + return config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers + #==================================================================# # Downloads huggingface checkpoints using aria2c if possible #==================================================================# From 29bb3f569b550dbac4dfaa999c641115ecafec64 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 13 May 2022 01:37:17 -0400 Subject: [PATCH 5/8] Fix a bug in OPTForCausalLM where self.lm_head is the wrong size --- aiserver.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/aiserver.py b/aiserver.py index f988e6ba..102b24f5 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1293,6 +1293,25 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go pass + # Fix a bug in OPTForCausalLM where self.lm_head is the wrong size + if(transformers_version == "4.19.0"): + try: + from transformers import OPTForCausalLM, OPTModel + except ImportError: + pass + else: + # This is the same as the original __init__ but with + # config.hidden_size + # replaced with + # config.word_embed_proj_dim + def new_init(self, config): + super(OPTForCausalLM, self).__init__(config) + self.model = OPTModel(config) + self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + self.post_init() + OPTForCausalLM.__init__ = new_init + + # Patch transformers to use our custom logit warpers from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper From 55079f672a0002513c0c61caf7d2fb5f35289f1d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 13 May 2022 01:51:55 -0400 Subject: [PATCH 6/8] Fix typo in soft prompt patching code --- aiserver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiserver.py b/aiserver.py index 102b24f5..150c8b68 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1268,7 +1268,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go input_ids.clamp_(max=self.config.vocab_size-1) if(hasattr(self, "transformer")): inputs_embeds = self.transformer.wte(input_ids) - elif(not hasattr(model.model, "decoder")): + elif(not hasattr(self.model, "decoder")): inputs_embeds = self.model.embed_tokens(input_ids) else: inputs_embeds = self.model.decoder.embed_tokens(input_ids) @@ -1279,7 +1279,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go vars.sp[shifted_input_ids.clamp(min=0)], inputs_embeds, ) - if(hasattr(self.model, "embed_scale")): + if(hasattr(self, "model") and hasattr(self.model, "embed_scale")): inputs_embeds *= self.model.embed_scale kwargs['inputs_embeds'] = inputs_embeds return old_forward(self, *args, **kwargs) From 1200173386a0f5fb9eb06d30f3b3a7ab5ab1538c Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 13 May 2022 10:45:28 -0400 Subject: [PATCH 7/8] Custom badwords for OPT Generated using: ``` import transformers tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-350m", fast=False) badwordsids_opt = [[v] for k, v in tokenizer.vocab.items() if any(c in k for c in "<>[]")] ``` --- aiserver.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index 150c8b68..bbe5ec9a 100644 --- a/aiserver.py +++ b/aiserver.py @@ -240,6 +240,7 @@ class vars: # badwords = [] # Array of str/chr values that should be removed from output badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting badwordsids_neox = [[0], [1], [44162], [9502], [12520], [31841], [36320], [49824], [34417], [6038], [34494], [24815], [26635], [24345], [3455], [28905], [44270], [17278], [32666], [46880], [7086], [43189], [37322], [17778], [20879], [49821], [3138], [14490], [4681], [21391], [26786], [43134], [9336], [683], [48074], [41256], [19181], [29650], [28532], [36487], [45114], [46275], [16445], [15104], [11337], [1168], [5647], [29], [27482], [44965], [43782], [31011], [42944], [47389], [6334], [17548], [38329], [32044], [35487], [2239], [34761], [7444], [1084], [12399], [18990], [17636], [39083], [1184], [35830], [28365], [16731], [43467], [47744], [1138], [16079], [40116], [45564], [18297], [42368], [5456], [18022], [42696], [34476], [23505], [23741], [39334], [37944], [45382], [38709], [33440], [26077], [43600], [34418], [36033], [6660], [48167], [48471], [15775], [19884], [41533], [1008], [31053], [36692], [46576], [20095], [20629], [31759], [46410], [41000], [13488], [30952], [39258], [16160], [27655], [22367], [42767], [43736], [49694], [13811], [12004], [46768], [6257], [37471], [5264], [44153], [33805], [20977], [21083], [25416], [14277], [31096], [42041], [18331], [33376], [22372], [46294], [28379], [38475], [1656], [5204], [27075], [50001], [16616], [11396], [7748], [48744], [35402], [28120], [41512], [4207], [43144], [14767], [15640], [16595], [41305], [44479], [38958], [18474], [22734], [30522], [46267], [60], [13976], [31830], [48701], [39822], [9014], [21966], [31422], [28052], [34607], [2479], [3851], [32214], [44082], [45507], [3001], [34368], [34758], [13380], [38363], [4299], [46802], [30996], [12630], [49236], [7082], [8795], [5218], [44740], [9686], [9983], [45301], [27114], [40125], [1570], [26997], [544], [5290], [49193], [23781], [14193], [40000], [2947], [43781], [9102], [48064], [42274], [18772], [49384], [9884], [45635], [43521], [31258], [32056], [47686], [21760], [13143], [10148], [26119], [44308], [31379], [36399], [23983], [46694], [36134], [8562], [12977], [35117], [28591], [49021], [47093], [28653], [29013], [46468], [8605], [7254], [25896], [5032], [8168], [36893], [38270], [20499], [27501], [34419], [29547], [28571], [36586], [20871], [30537], [26842], [21375], [31148], [27618], [33094], [3291], [31789], [28391], [870], [9793], [41361], [47916], [27468], [43856], [8850], [35237], [15707], [47552], [2730], [41449], [45488], [3073], [49806], [21938], [24430], [22747], [20924], [46145], [20481], [20197], [8239], [28231], [17987], [42804], [47269], [29972], [49884], [21382], [46295], [36676], [34616], [3921], [26991], [27720], [46265], [654], [9855], [40354], [5291], [34904], [44342], [2470], [14598], [880], [19282], [2498], [24237], [21431], [16369], [8994], [44524], [45662], [13663], [37077], [1447], [37786], [30863], [42854], [1019], [20322], [4398], [12159], [44072], [48664], [31547], [18736], [9259], [31], [16354], [21810], [4357], [37982], [5064], [2033], [32871], [47446], [62], [22158], [37387], [8743], [47007], [17981], [11049], [4622], [37916], [36786], [35138], [29925], [14157], [18095], [27829], [1181], [22226], [5709], [4725], [30189], [37014], [1254], [11380], [42989], [696], [24576], [39487], [30119], [1092], [8088], [2194], [9899], [14412], [21828], [3725], [13544], [5180], [44679], [34398], [3891], [28739], [14219], [37594], [49550], [11326], [6904], [17266], [5749], [10174], [23405], [9955], [38271], [41018], [13011], [48392], [36784], [24254], [21687], [23734], [5413], [41447], [45472], [10122], [17555], [15830], [47384], [12084], [31350], [47940], [11661], [27988], [45443], [905], [49651], [16614], [34993], [6781], [30803], [35869], [8001], [41604], [28118], [46462], [46762], [16262], [17281], [5774], [10943], [5013], [18257], [6750], [4713], [3951], [11899], [38791], [16943], [37596], [9318], [18413], [40473], [13208], [16375]] + badwordsids_opt = [[44717], [46613], [48513], [49923], [50185], [48755], [8488], [43303], [49659], [48601], [49817], [45405], [48742], [49925], [47720], [11227], [48937], [48784], [50017], [42248], [49310], [48082], [49895], [50025], [49092], [49007], [8061], [44226], [0], [742], [28578], [15698], [49784], [46679], [39365], [49281], [49609], [48081], [48906], [46161], [48554], [49670], [48677], [49721], [49632], [48610], [48462], [47457], [10975], [46077], [28696], [48709], [43839], [49798], [49154], [48203], [49625], [48395], [50155], [47161], [49095], [48833], [49420], [49666], [48443], [22176], [49242], [48651], [49138], [49750], [40389], [48021], [21838], [49070], [45333], [40862], [1], [49915], [33525], [49858], [50254], [44403], [48992], [48872], [46117], [49853], [47567], [50206], [41552], [50068], [48999], [49703], [49940], [49329], [47620], [49868], [49962], [2], [44082], [50236], [31274], [50260], [47052], [42645], [49177], [17523], [48691], [49900], [49069], [49358], [48794], [47529], [46479], [48457], [646], [49910], [48077], [48935], [46386], [48902], [49151], [48759], [49803], [45587], [48392], [47789], [48654], [49836], [49230], [48188], [50264], [46844], [44690], [48505], [50161], [27779], [49995], [41833], [50154], [49097], [48520], [50018], [8174], [50084], [49366], [49526], [50193], [7479], [49982], [3]] deletewi = None # Temporary storage for UID to delete wirmvwhtsp = False # Whether to remove leading whitespace from WI entries widepth = 3 # How many historical actions to scan for WI hits @@ -917,6 +918,9 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)") vars.model_type = "gpt_neo" + if(vars.model_type == "opt"): + vars.badwordsids = vars.badwordsids_opt + if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): loadmodelsettings() loadsettings() @@ -1730,7 +1734,7 @@ else: if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)): raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder") import tpu_mtj_backend - if(vars.model == "TPUMeshTransformerGPTNeoX"): + if(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "opt"): tpu_mtj_backend.pad_token_id = 1 tpu_mtj_backend.vars = vars tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback From a051bf4397fb1e44d0ad9d5597442ea16b369d71 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 13 May 2022 10:45:57 -0400 Subject: [PATCH 8/8] OPT breakmodel bug fix --- breakmodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/breakmodel.py b/breakmodel.py index 262d39bd..eb49e669 100644 --- a/breakmodel.py +++ b/breakmodel.py @@ -816,7 +816,7 @@ def new_forward_opt( if breakmodel: device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks) layer_outputs = decoder_layer( - hidden_states, + hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states, attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask, layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None), past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,