From 25c9be5d0248a607085f8feae6bb08543f1a8320 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 25 Nov 2021 18:09:16 -0500
Subject: [PATCH] Breakmodel support for GPTJModel

---
 aiserver.py   |  4 +++-
 breakmodel.py | 30 ++++++++++--------------------
 2 files changed, 13 insertions(+), 21 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index c550fc2c..ba65f79d 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -298,10 +298,12 @@ def device_config(model):
     model.transformer.ln_f.to(breakmodel.primary_device)
     if(hasattr(model, 'lm_head')):
         model.lm_head.to(breakmodel.primary_device)
-    if(not hasattr(model.config, 'rotary') or not model.config.rotary):
+    if(hasattr(model.transformer, 'wpe')):
         model.transformer.wpe.to(breakmodel.primary_device)
     gc.collect()
     GPTNeoModel.forward = breakmodel.new_forward
+    if("GPTJModel" in globals()):
+        GPTJModel.forward = breakmodel.new_forward
     generator = model.generate
     breakmodel.move_hidden_layers(model.transformer)
diff --git a/breakmodel.py b/breakmodel.py
index e1f3ce6e..d112e4b5 100644
--- a/breakmodel.py
+++ b/breakmodel.py
@@ -325,33 +325,27 @@ def new_forward(
     # Attention mask.
     if attention_mask is not None:
         assert batch_size > 0, "batch_size has to be defined and > 0"
-        global_attention_mask = attention_mask.view(batch_size, -1)
+        attention_mask = attention_mask.view(batch_size, -1)
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        global_attention_mask = global_attention_mask[:, None, None, :]
+        attention_mask = attention_mask[:, None, None, :]
 
-        # Since global_attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        global_attention_mask = global_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        global_attention_mask = (1.0 - global_attention_mask) * -10000.0
-    else:
-        global_attention_mask = None
-
-    # Local causal attention mask
-    batch_size, seq_length = input_shape
-    full_seq_length = seq_length + past_length
+        attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        attention_mask = (1.0 - attention_mask) * -10000.0
 
     # Prepare head mask if needed
     # 1.0 in head_mask indicate we keep the head
     # attention_probs has shape bsz x num_heads x N x N
     # head_mask has shape n_layer x batch x num_heads x N x N
-    head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+    head_mask = self.get_head_mask(head_mask, getattr(self.config, "num_layers", None) or self.config.n_layer)
 
     if inputs_embeds is None:
         if breakmodel:
@@ -367,7 +361,7 @@ def new_forward(
                 inputs_embeds[:, pos:pos+emb.shape[1]] = emb
                 offset += emb.shape[1]
 
-    if hasattr(self, 'rotary') and self.rotary:
+    if getattr(self, "wpe", None) is None:
         hidden_states = inputs_embeds
     else:
         if breakmodel:
@@ -403,9 +397,6 @@ def new_forward(
                     with torch.cuda.stream(copystream):
                         torch.cuda.comm.broadcast(param2.data,out = [param1.data])
 
-        attn_type = self.config.attention_layers[i]
-        attn_mask = global_attention_mask
-
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states.cpu(),)
 
@@ -413,8 +404,7 @@ def new_forward(
 
             if use_cache:
                 logger.warning(
-                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                    "`use_cache=False`..."
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
 
@@ -429,7 +419,7 @@ def new_forward(
                 create_custom_forward(block),
                 hidden_states,
                 None,
-                attn_mask,
+                attention_mask,
                 head_mask[i],
             )
         else:
@@ -438,7 +428,7 @@ def new_forward(
             outputs = block(
                 hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
                 layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None and i >= ram_blocks and len(layer_past) and layer_past[0].device.index != device else layer_past,
-                attention_mask=attn_mask.to(device) if breakmodel and attn_mask is not None else attn_mask,
+                attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
                 head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i],
                 use_cache=use_cache,
                 output_attentions=output_attentions,
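Below is a minimal, self-contained Python sketch (not part of the patch) of the two feature checks the diff switches to so that one replacement forward can serve both GPTNeoModel and GPTJModel: detecting the learned positional-embedding module `wpe` rather than a `config.rotary` flag, and reading the layer count from either `num_layers` (GPT-Neo) or `n_layer` (GPT-J). The helper names and dummy objects are hypothetical stand-ins, not KoboldAI or transformers code.

# Hypothetical illustration of the checks used in the patch above; the dummy
# classes stand in for real transformers models/configs.

class _DummyConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


class _DummyTransformer:
    def __init__(self, has_wpe, config):
        if has_wpe:
            self.wpe = object()  # GPT-Neo exposes learned positional embeddings here
        self.config = config


def uses_learned_positions(transformer):
    # Same test as the patch: look for the `wpe` module instead of a
    # `config.rotary` flag, which GPT-J configs do not define.
    return getattr(transformer, "wpe", None) is not None


def layer_count(config):
    # Same fallback as the patched get_head_mask call: GPT-Neo configs
    # expose `num_layers`, GPT-J configs expose `n_layer`.
    return getattr(config, "num_layers", None) or config.n_layer


neo = _DummyTransformer(has_wpe=True, config=_DummyConfig(num_layers=24))
gptj = _DummyTransformer(has_wpe=False, config=_DummyConfig(n_layer=28))

assert uses_learned_positions(neo) and not uses_learned_positions(gptj)
assert layer_count(neo.config) == 24 and layer_count(gptj.config) == 28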