From e7f65cee09005729bb4c8a49ba6e69563d06fd38 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 1 Feb 2022 12:49:07 -0500 Subject: [PATCH] XGLM breakmodel --- aiserver.py | 38 +++++--- breakmodel.py | 261 +++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 260 insertions(+), 39 deletions(-) diff --git a/aiserver.py b/aiserver.py index 709760d2..3a3079d4 100644 --- a/aiserver.py +++ b/aiserver.py @@ -374,18 +374,29 @@ def device_config(model): return model.half().to('cpu') gc.collect() - model.transformer.wte.to(breakmodel.primary_device) - model.transformer.ln_f.to(breakmodel.primary_device) - if(hasattr(model, 'lm_head')): - model.lm_head.to(breakmodel.primary_device) - if(hasattr(model.transformer, 'wpe')): - model.transformer.wpe.to(breakmodel.primary_device) + if(hasattr(model, "transformer")): + model.transformer.wte.to(breakmodel.primary_device) + model.transformer.ln_f.to(breakmodel.primary_device) + if(hasattr(model, 'lm_head')): + model.lm_head.to(breakmodel.primary_device) + if(hasattr(model.transformer, 'wpe')): + model.transformer.wpe.to(breakmodel.primary_device) + else: + model.model.embed_tokens.to(breakmodel.primary_device) + model.model.layer_norm.to(breakmodel.primary_device) + model.model.lm_head.to(breakmodel.primary_device) + model.model.embed_positions.to(breakmodel.primary_device) gc.collect() - GPTNeoModel.forward = breakmodel.new_forward + GPTNeoModel.forward = breakmodel.new_forward_neo if("GPTJModel" in globals()): - GPTJModel.forward = breakmodel.new_forward + GPTJModel.forward = breakmodel.new_forward_neo + if("XGLMModel" in globals()): + XGLMModel.forward = breakmodel.new_forward_xglm generator = model.generate - breakmodel.move_hidden_layers(model.transformer) + if(hasattr(model, "transformer")): + breakmodel.move_hidden_layers(model.transformer) + else: + breakmodel.move_hidden_layers(model.model, model.model.layers) #==================================================================# # Allow the models to override some settings @@ -723,10 +734,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme if(not vars.noai): print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END)) from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer - try: - from transformers import GPTJModel - except: - pass + for m in ("GPTJModel", "XGLMModel"): + try: + globals()[m] = __import__("transformers." + m, fromlist=[...]) + except: + pass import transformers.generation_utils from transformers import __version__ as transformers_version diff --git a/breakmodel.py b/breakmodel.py index 087a112a..9818e6d9 100644 --- a/breakmodel.py +++ b/breakmodel.py @@ -212,14 +212,17 @@ Copyright 2018 The Hugging Face team import torch +from torch import nn import torch.cuda.comm import copy import gc import sys import itertools import bisect +import random +from typing import Optional -from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions from transformers.utils import logging logger = logging.get_logger(__name__) @@ -230,22 +233,40 @@ gpu_blocks = [] primary_device = 0 -def move_hidden_layers(transformer): +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +def move_hidden_layers(transformer, h=None): + if h is None: + h = transformer.h + assert len(gpu_blocks) <= torch.cuda.device_count() - assert sum(gpu_blocks) <= len(transformer.h) - ram_blocks = len(transformer.h) - sum(gpu_blocks) + assert sum(gpu_blocks) <= len(h) + ram_blocks = len(h) - sum(gpu_blocks) transformer.extrastorage = {} torch.cuda.empty_cache() able_to_pin_layers = True for i in range(ram_blocks): - transformer.h[i].to("cpu") - transformer.extrastorage[i] = copy.deepcopy(transformer.h[i]) + h[i].to("cpu") + transformer.extrastorage[i] = copy.deepcopy(h[i]) smalltensor = torch.tensor(0).to(primary_device) - for param1 in transformer.h[i].parameters(): + for param1 in h[i].parameters(): param1.data = smalltensor - transformer.h[i].to(primary_device) + h[i].to(primary_device) for param in transformer.extrastorage[i].parameters(): param.requires_grad = False param.data = param.data.detach() @@ -259,34 +280,34 @@ def move_hidden_layers(transformer): torch.cuda.empty_cache() if ram_blocks: - for param1,param2 in zip(transformer.h[0].parameters(),transformer.extrastorage[0].parameters()): + for param1,param2 in zip(h[0].parameters(),transformer.extrastorage[0].parameters()): param1.data = param2.data.to(primary_device, non_blocking=False).detach() - for param1,param2 in zip(transformer.h[ram_blocks-1].parameters(),transformer.extrastorage[ram_blocks-1].parameters()): + for param1,param2 in zip(h[ram_blocks-1].parameters(),transformer.extrastorage[ram_blocks-1].parameters()): param1.data = param2.data.to(primary_device, non_blocking=False).detach() i = ram_blocks for j in range(len(gpu_blocks)): for _ in range(gpu_blocks[j]): - transformer.h[i].to(j) + h[i].to(j) i += 1 -def new_forward( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - embs=None, - ): +def new_forward_neo( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embs=None, +): assert len(gpu_blocks) <= torch.cuda.device_count() assert sum(gpu_blocks) <= len(self.h) ram_blocks = len(self.h) - sum(gpu_blocks) @@ -477,3 +498,191 @@ def new_forward( hidden_states=all_hidden_states, attentions=all_self_attentions, ) + + +def new_forward_xglm( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, +): + assert len(gpu_blocks) <= torch.cuda.device_count() + assert sum(gpu_blocks) <= len(self.layers) + ram_blocks = len(self.layers) - sum(gpu_blocks) + cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks)) + + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + if breakmodel: + input_ids = input_ids.to(primary_device) + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # embed positions + if breakmodel: + inputs_embeds = inputs_embeds.to(primary_device) + positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length) + if breakmodel: + positions = positions.to(primary_device) + + hidden_states = inputs_embeds + positions + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + if breakmodel and ram_blocks: + copystream = torch.cuda.Stream(device=primary_device, priority=-1) + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + for idx, decoder_layer in enumerate(self.layers): + i = idx + if breakmodel: + if i in range(ram_blocks): + index1 = (i+1)%ram_blocks + for param1,param2 in zip(self.layers[index1].parameters(),self.layers[(i-1)%ram_blocks].parameters()): + param1.data = param2.data + for param1,param2 in zip(self.layers[index1].parameters(),self.extrastorage[index1].parameters()): + with torch.cuda.stream(copystream): + torch.cuda.comm.broadcast(param2.data,out = [param1.data]) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + ) + else: + if breakmodel: + device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks) + layer_outputs = decoder_layer( + hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states, + attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask, + encoder_hidden_states=encoder_hidden_states.to(device) if encoder_hidden_states is not None else None, + encoder_attention_mask=encoder_attention_mask.to(device) if encoder_attention_mask is not None else None, + layer_head_mask=((head_mask[idx].to(device) if head_mask[idx] is not None else None) if head_mask is not None else None), + cross_attn_layer_head_mask=( + (cross_attn_head_mask[idx].to(device) if cross_attn_head_mask[idx] is not None else None) if cross_attn_head_mask is not None else None + ), + past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if breakmodel: + if i in range(ram_blocks): + torch.cuda.synchronize() + torch.cuda.empty_cache() + + if breakmodel: + if ram_blocks: + del copystream + torch.cuda.empty_cache() + hidden_states = hidden_states.to(primary_device) + hidden_states = self.layer_norm(hidden_states) + if breakmodel: + hidden_states = hidden_states.to(primary_device) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + )