XGLM breakmodel
This commit is contained in:
parent
c14e6fe5d2
commit
e7f65cee09
18
aiserver.py
18
aiserver.py
|
@ -374,18 +374,29 @@ def device_config(model):
|
||||||
return
|
return
|
||||||
model.half().to('cpu')
|
model.half().to('cpu')
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
if(hasattr(model, "transformer")):
|
||||||
model.transformer.wte.to(breakmodel.primary_device)
|
model.transformer.wte.to(breakmodel.primary_device)
|
||||||
model.transformer.ln_f.to(breakmodel.primary_device)
|
model.transformer.ln_f.to(breakmodel.primary_device)
|
||||||
if(hasattr(model, 'lm_head')):
|
if(hasattr(model, 'lm_head')):
|
||||||
model.lm_head.to(breakmodel.primary_device)
|
model.lm_head.to(breakmodel.primary_device)
|
||||||
if(hasattr(model.transformer, 'wpe')):
|
if(hasattr(model.transformer, 'wpe')):
|
||||||
model.transformer.wpe.to(breakmodel.primary_device)
|
model.transformer.wpe.to(breakmodel.primary_device)
|
||||||
|
else:
|
||||||
|
model.model.embed_tokens.to(breakmodel.primary_device)
|
||||||
|
model.model.layer_norm.to(breakmodel.primary_device)
|
||||||
|
model.model.lm_head.to(breakmodel.primary_device)
|
||||||
|
model.model.embed_positions.to(breakmodel.primary_device)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
GPTNeoModel.forward = breakmodel.new_forward
|
GPTNeoModel.forward = breakmodel.new_forward_neo
|
||||||
if("GPTJModel" in globals()):
|
if("GPTJModel" in globals()):
|
||||||
GPTJModel.forward = breakmodel.new_forward
|
GPTJModel.forward = breakmodel.new_forward_neo
|
||||||
|
if("XGLMModel" in globals()):
|
||||||
|
XGLMModel.forward = breakmodel.new_forward_xglm
|
||||||
generator = model.generate
|
generator = model.generate
|
||||||
|
if(hasattr(model, "transformer")):
|
||||||
breakmodel.move_hidden_layers(model.transformer)
|
breakmodel.move_hidden_layers(model.transformer)
|
||||||
|
else:
|
||||||
|
breakmodel.move_hidden_layers(model.model, model.model.layers)
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Allow the models to override some settings
|
# Allow the models to override some settings
|
||||||
|
@ -723,8 +734,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
if(not vars.noai):
|
if(not vars.noai):
|
||||||
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||||
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
|
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
|
||||||
|
for m in ("GPTJModel", "XGLMModel"):
|
||||||
try:
|
try:
|
||||||
from transformers import GPTJModel
|
globals()[m] = __import__("transformers." + m, fromlist=[...])
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
import transformers.generation_utils
|
import transformers.generation_utils
|
||||||
|
|
235
breakmodel.py
235
breakmodel.py
|
@ -212,14 +212,17 @@ Copyright 2018 The Hugging Face team
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
import torch.cuda.comm
|
import torch.cuda.comm
|
||||||
import copy
|
import copy
|
||||||
import gc
|
import gc
|
||||||
import sys
|
import sys
|
||||||
import itertools
|
import itertools
|
||||||
import bisect
|
import bisect
|
||||||
|
import random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions
|
||||||
|
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
@ -230,22 +233,40 @@ gpu_blocks = []
|
||||||
primary_device = 0
|
primary_device = 0
|
||||||
|
|
||||||
|
|
||||||
def move_hidden_layers(transformer):
|
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
||||||
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||||
|
"""
|
||||||
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||||
|
"""
|
||||||
|
bsz, src_len = mask.size()
|
||||||
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
|
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||||
|
|
||||||
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
|
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
|
def move_hidden_layers(transformer, h=None):
|
||||||
|
if h is None:
|
||||||
|
h = transformer.h
|
||||||
|
|
||||||
assert len(gpu_blocks) <= torch.cuda.device_count()
|
assert len(gpu_blocks) <= torch.cuda.device_count()
|
||||||
assert sum(gpu_blocks) <= len(transformer.h)
|
assert sum(gpu_blocks) <= len(h)
|
||||||
ram_blocks = len(transformer.h) - sum(gpu_blocks)
|
ram_blocks = len(h) - sum(gpu_blocks)
|
||||||
|
|
||||||
transformer.extrastorage = {}
|
transformer.extrastorage = {}
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
able_to_pin_layers = True
|
able_to_pin_layers = True
|
||||||
for i in range(ram_blocks):
|
for i in range(ram_blocks):
|
||||||
transformer.h[i].to("cpu")
|
h[i].to("cpu")
|
||||||
transformer.extrastorage[i] = copy.deepcopy(transformer.h[i])
|
transformer.extrastorage[i] = copy.deepcopy(h[i])
|
||||||
smalltensor = torch.tensor(0).to(primary_device)
|
smalltensor = torch.tensor(0).to(primary_device)
|
||||||
for param1 in transformer.h[i].parameters():
|
for param1 in h[i].parameters():
|
||||||
param1.data = smalltensor
|
param1.data = smalltensor
|
||||||
transformer.h[i].to(primary_device)
|
h[i].to(primary_device)
|
||||||
for param in transformer.extrastorage[i].parameters():
|
for param in transformer.extrastorage[i].parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
param.data = param.data.detach()
|
param.data = param.data.detach()
|
||||||
|
@ -259,20 +280,20 @@ def move_hidden_layers(transformer):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
if ram_blocks:
|
if ram_blocks:
|
||||||
for param1,param2 in zip(transformer.h[0].parameters(),transformer.extrastorage[0].parameters()):
|
for param1,param2 in zip(h[0].parameters(),transformer.extrastorage[0].parameters()):
|
||||||
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
|
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
|
||||||
|
|
||||||
for param1,param2 in zip(transformer.h[ram_blocks-1].parameters(),transformer.extrastorage[ram_blocks-1].parameters()):
|
for param1,param2 in zip(h[ram_blocks-1].parameters(),transformer.extrastorage[ram_blocks-1].parameters()):
|
||||||
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
|
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
|
||||||
|
|
||||||
i = ram_blocks
|
i = ram_blocks
|
||||||
for j in range(len(gpu_blocks)):
|
for j in range(len(gpu_blocks)):
|
||||||
for _ in range(gpu_blocks[j]):
|
for _ in range(gpu_blocks[j]):
|
||||||
transformer.h[i].to(j)
|
h[i].to(j)
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
|
||||||
def new_forward(
|
def new_forward_neo(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
|
@ -286,7 +307,7 @@ def new_forward(
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
embs=None,
|
embs=None,
|
||||||
):
|
):
|
||||||
assert len(gpu_blocks) <= torch.cuda.device_count()
|
assert len(gpu_blocks) <= torch.cuda.device_count()
|
||||||
assert sum(gpu_blocks) <= len(self.h)
|
assert sum(gpu_blocks) <= len(self.h)
|
||||||
ram_blocks = len(self.h) - sum(gpu_blocks)
|
ram_blocks = len(self.h) - sum(gpu_blocks)
|
||||||
|
@ -477,3 +498,191 @@ def new_forward(
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def new_forward_xglm(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
assert len(gpu_blocks) <= torch.cuda.device_count()
|
||||||
|
assert sum(gpu_blocks) <= len(self.layers)
|
||||||
|
ram_blocks = len(self.layers) - sum(gpu_blocks)
|
||||||
|
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
|
||||||
|
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
if breakmodel:
|
||||||
|
input_ids = input_ids.to(primary_device)
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||||
|
|
||||||
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
|
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
|
# expand encoder attention mask
|
||||||
|
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||||
|
|
||||||
|
# embed positions
|
||||||
|
if breakmodel:
|
||||||
|
inputs_embeds = inputs_embeds.to(primary_device)
|
||||||
|
positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
|
||||||
|
if breakmodel:
|
||||||
|
positions = positions.to(primary_device)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds + positions
|
||||||
|
|
||||||
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
if breakmodel and ram_blocks:
|
||||||
|
copystream = torch.cuda.Stream(device=primary_device, priority=-1)
|
||||||
|
|
||||||
|
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||||
|
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||||
|
if attn_mask is not None:
|
||||||
|
assert attn_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
|
i = idx
|
||||||
|
if breakmodel:
|
||||||
|
if i in range(ram_blocks):
|
||||||
|
index1 = (i+1)%ram_blocks
|
||||||
|
for param1,param2 in zip(self.layers[index1].parameters(),self.layers[(i-1)%ram_blocks].parameters()):
|
||||||
|
param1.data = param2.data
|
||||||
|
for param1,param2 in zip(self.layers[index1].parameters(),self.extrastorage[index1].parameters()):
|
||||||
|
with torch.cuda.stream(copystream):
|
||||||
|
torch.cuda.comm.broadcast(param2.data,out = [param1.data])
|
||||||
|
|
||||||
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
dropout_probability = random.uniform(0, 1)
|
||||||
|
if self.training and (dropout_probability < self.layerdrop):
|
||||||
|
continue
|
||||||
|
|
||||||
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warning(
|
||||||
|
"`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, output_attentions, use_cache)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(decoder_layer),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
|
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if breakmodel:
|
||||||
|
device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
|
||||||
|
attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states.to(device) if encoder_hidden_states is not None else None,
|
||||||
|
encoder_attention_mask=encoder_attention_mask.to(device) if encoder_attention_mask is not None else None,
|
||||||
|
layer_head_mask=((head_mask[idx].to(device) if head_mask[idx] is not None else None) if head_mask is not None else None),
|
||||||
|
cross_attn_layer_head_mask=(
|
||||||
|
(cross_attn_head_mask[idx].to(device) if cross_attn_head_mask[idx] is not None else None) if cross_attn_head_mask is not None else None
|
||||||
|
),
|
||||||
|
past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
|
if breakmodel:
|
||||||
|
if i in range(ram_blocks):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
if breakmodel:
|
||||||
|
if ram_blocks:
|
||||||
|
del copystream
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
hidden_states = hidden_states.to(primary_device)
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
if breakmodel:
|
||||||
|
hidden_states = hidden_states.to(primary_device)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue