XGLM breakmodel

This commit is contained in:
Gnome Ann 2022-02-01 12:49:07 -05:00
parent c14e6fe5d2
commit e7f65cee09
2 changed files with 260 additions and 39 deletions

View File

@ -374,18 +374,29 @@ def device_config(model):
return
model.half().to('cpu')
gc.collect()
model.transformer.wte.to(breakmodel.primary_device)
model.transformer.ln_f.to(breakmodel.primary_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.primary_device)
if(hasattr(model.transformer, 'wpe')):
model.transformer.wpe.to(breakmodel.primary_device)
if(hasattr(model, "transformer")):
model.transformer.wte.to(breakmodel.primary_device)
model.transformer.ln_f.to(breakmodel.primary_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.primary_device)
if(hasattr(model.transformer, 'wpe')):
model.transformer.wpe.to(breakmodel.primary_device)
else:
model.model.embed_tokens.to(breakmodel.primary_device)
model.model.layer_norm.to(breakmodel.primary_device)
model.model.lm_head.to(breakmodel.primary_device)
model.model.embed_positions.to(breakmodel.primary_device)
gc.collect()
GPTNeoModel.forward = breakmodel.new_forward
GPTNeoModel.forward = breakmodel.new_forward_neo
if("GPTJModel" in globals()):
GPTJModel.forward = breakmodel.new_forward
GPTJModel.forward = breakmodel.new_forward_neo
if("XGLMModel" in globals()):
XGLMModel.forward = breakmodel.new_forward_xglm
generator = model.generate
breakmodel.move_hidden_layers(model.transformer)
if(hasattr(model, "transformer")):
breakmodel.move_hidden_layers(model.transformer)
else:
breakmodel.move_hidden_layers(model.model, model.model.layers)
#==================================================================#
# Allow the models to override some settings
@ -723,10 +734,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
try:
from transformers import GPTJModel
except:
pass
for m in ("GPTJModel", "XGLMModel"):
try:
globals()[m] = __import__("transformers." + m, fromlist=[...])
except:
pass
import transformers.generation_utils
from transformers import __version__ as transformers_version

View File

@ -212,14 +212,17 @@ Copyright 2018 The Hugging Face team
import torch
from torch import nn
import torch.cuda.comm
import copy
import gc
import sys
import itertools
import bisect
import random
from typing import Optional
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions
from transformers.utils import logging
logger = logging.get_logger(__name__)
@ -230,22 +233,40 @@ gpu_blocks = []
primary_device = 0
def move_hidden_layers(transformer):
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
def move_hidden_layers(transformer, h=None):
if h is None:
h = transformer.h
assert len(gpu_blocks) <= torch.cuda.device_count()
assert sum(gpu_blocks) <= len(transformer.h)
ram_blocks = len(transformer.h) - sum(gpu_blocks)
assert sum(gpu_blocks) <= len(h)
ram_blocks = len(h) - sum(gpu_blocks)
transformer.extrastorage = {}
torch.cuda.empty_cache()
able_to_pin_layers = True
for i in range(ram_blocks):
transformer.h[i].to("cpu")
transformer.extrastorage[i] = copy.deepcopy(transformer.h[i])
h[i].to("cpu")
transformer.extrastorage[i] = copy.deepcopy(h[i])
smalltensor = torch.tensor(0).to(primary_device)
for param1 in transformer.h[i].parameters():
for param1 in h[i].parameters():
param1.data = smalltensor
transformer.h[i].to(primary_device)
h[i].to(primary_device)
for param in transformer.extrastorage[i].parameters():
param.requires_grad = False
param.data = param.data.detach()
@ -259,34 +280,34 @@ def move_hidden_layers(transformer):
torch.cuda.empty_cache()
if ram_blocks:
for param1,param2 in zip(transformer.h[0].parameters(),transformer.extrastorage[0].parameters()):
for param1,param2 in zip(h[0].parameters(),transformer.extrastorage[0].parameters()):
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
for param1,param2 in zip(transformer.h[ram_blocks-1].parameters(),transformer.extrastorage[ram_blocks-1].parameters()):
for param1,param2 in zip(h[ram_blocks-1].parameters(),transformer.extrastorage[ram_blocks-1].parameters()):
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
i = ram_blocks
for j in range(len(gpu_blocks)):
for _ in range(gpu_blocks[j]):
transformer.h[i].to(j)
h[i].to(j)
i += 1
def new_forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embs=None,
):
def new_forward_neo(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embs=None,
):
assert len(gpu_blocks) <= torch.cuda.device_count()
assert sum(gpu_blocks) <= len(self.h)
ram_blocks = len(self.h) - sum(gpu_blocks)
@ -477,3 +498,191 @@ def new_forward(
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
def new_forward_xglm(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
assert len(gpu_blocks) <= torch.cuda.device_count()
assert sum(gpu_blocks) <= len(self.layers)
ram_blocks = len(self.layers) - sum(gpu_blocks)
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
if breakmodel:
input_ids = input_ids.to(primary_device)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
# embed positions
if breakmodel:
inputs_embeds = inputs_embeds.to(primary_device)
positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
if breakmodel:
positions = positions.to(primary_device)
hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
if breakmodel and ram_blocks:
copystream = torch.cuda.Stream(device=primary_device, priority=-1)
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (
len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers):
i = idx
if breakmodel:
if i in range(ram_blocks):
index1 = (i+1)%ram_blocks
for param1,param2 in zip(self.layers[index1].parameters(),self.layers[(i-1)%ram_blocks].parameters()):
param1.data = param2.data
for param1,param2 in zip(self.layers[index1].parameters(),self.extrastorage[index1].parameters()):
with torch.cuda.stream(copystream):
torch.cuda.comm.broadcast(param2.data,out = [param1.data])
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
)
else:
if breakmodel:
device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
layer_outputs = decoder_layer(
hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
encoder_hidden_states=encoder_hidden_states.to(device) if encoder_hidden_states is not None else None,
encoder_attention_mask=encoder_attention_mask.to(device) if encoder_attention_mask is not None else None,
layer_head_mask=((head_mask[idx].to(device) if head_mask[idx] is not None else None) if head_mask is not None else None),
cross_attn_layer_head_mask=(
(cross_attn_head_mask[idx].to(device) if cross_attn_head_mask[idx] is not None else None) if cross_attn_head_mask is not None else None
),
past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
if breakmodel:
if i in range(ram_blocks):
torch.cuda.synchronize()
torch.cuda.empty_cache()
if breakmodel:
if ram_blocks:
del copystream
torch.cuda.empty_cache()
hidden_states = hidden_states.to(primary_device)
hidden_states = self.layer_norm(hidden_states)
if breakmodel:
hidden_states = hidden_states.to(primary_device)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)