mirror of https://github.com/KoboldAI/KoboldAI-Client.git
OPT breakmodel
182  breakmodel.py
@@ -633,11 +633,11 @@ def new_forward_xglm(
             layer_outputs = decoder_layer(
                 hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
                 attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
-                encoder_hidden_states=encoder_hidden_states.to(device) if encoder_hidden_states is not None else None,
-                encoder_attention_mask=encoder_attention_mask.to(device) if encoder_attention_mask is not None else None,
-                layer_head_mask=((head_mask[idx].to(device) if head_mask[idx] is not None else None) if head_mask is not None else None),
+                encoder_hidden_states=encoder_hidden_states.to(device) if breakmodel and encoder_hidden_states is not None else encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask.to(device) if breakmodel and encoder_attention_mask is not None else encoder_attention_mask,
+                layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None),
                 cross_attn_layer_head_mask=(
-                    (cross_attn_head_mask[idx].to(device) if cross_attn_head_mask[idx] is not None else None) if cross_attn_head_mask is not None else None
+                    (cross_attn_head_mask[idx].to(device) if breakmodel and cross_attn_head_mask[idx] is not None else cross_attn_head_mask[idx]) if cross_attn_head_mask is not None else None
                 ),
                 past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,
                 output_attentions=output_attentions,
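Note (not part of the commit): the hunk above turns each `else None` fallback into a pass-through, so when breakmodel is inactive the decoder layer still receives the caller's original tensors instead of `None`. A minimal sketch of the per-argument pattern the new lines inline, using a hypothetical helper name `route_to`:

    import torch

    def route_to(tensor, device, breakmodel_active):
        # Only move the tensor when breakmodel is active and the tensor exists;
        # otherwise return it unchanged (the old code collapsed it to None).
        if tensor is None:
            return None
        return tensor.to(device) if breakmodel_active else tensor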
@@ -686,3 +686,177 @@ def new_forward_xglm(
         attentions=all_self_attns,
         cross_attentions=all_cross_attentions,
     )
+
+
+def new_forward_opt(
+    self,
+    input_ids=None,
+    attention_mask=None,
+    head_mask=None,
+    past_key_values=None,
+    inputs_embeds=None,
+    use_cache=None,
+    output_attentions=None,
+    output_hidden_states=None,
+    return_dict=None,
+):
+    assert len(gpu_blocks) <= torch.cuda.device_count()
+    assert sum(gpu_blocks) <= len(self.layers)
+    ram_blocks = len(self.layers) - sum(gpu_blocks)
+    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
+
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+    elif input_ids is not None:
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+    elif inputs_embeds is not None:
+        input_shape = inputs_embeds.size()[:-1]
+    else:
+        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+    if inputs_embeds is None:
+        if breakmodel:
+            input_ids = input_ids.to(primary_device)
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    # embed positions
+    if breakmodel:
+        inputs_embeds = inputs_embeds.to(primary_device)
+    if attention_mask is None:
+        attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+
+    positions = self.embed_positions(attention_mask)[:, past_key_values_length:, :]
+    if breakmodel:
+        positions = positions.to(primary_device)
+
+    attention_mask = self._prepare_decoder_attention_mask(
+        attention_mask, input_shape, inputs_embeds, past_key_values_length
+    )
+
+    if self.project_in is not None:
+        inputs_embeds = self.project_in(inputs_embeds)
+
+    hidden_states = inputs_embeds + positions
+
+    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    if breakmodel and ram_blocks:
+        copystream = torch.cuda.Stream(device=primary_device, priority=-1)
+
+    # check if head_mask has a correct number of layers specified if desired
+    for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
+        if attn_mask is not None:
+            if attn_mask.size()[0] != (len(self.layers)):
+                raise ValueError(
+                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                    f" {head_mask.size()[0]}."
+                )
+
+    for idx, decoder_layer in enumerate(self.layers):
+        i = idx
+        if breakmodel:
+            if i in range(ram_blocks):
+                index1 = (i+1)%ram_blocks
+                for param1,param2 in zip(self.layers[index1].parameters(),self.layers[(i-1)%ram_blocks].parameters()):
+                    param1.data = param2.data
+                for param1,param2 in zip(self.layers[index1].parameters(),self.extrastorage[index1].parameters()):
+                    with torch.cuda.stream(copystream):
+                        torch.cuda.comm.broadcast(param2.data,out = [param1.data])
+
+        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+        dropout_probability = random.uniform(0, 1)
+        if self.training and (dropout_probability < self.layerdrop):
+            continue
+
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        if self.gradient_checkpointing and self.training:
+
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs, output_attentions, None)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(decoder_layer),
+                hidden_states,
+                attention_mask,
+                head_mask[idx] if head_mask is not None else None,
+                None,
+            )
+        else:
+            if breakmodel:
+                device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
+                layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None),
+                past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+        if breakmodel:
+            if i in range(ram_blocks):
+                torch.cuda.synchronize()
+                torch.cuda.empty_cache()
+
+    if breakmodel:
+        if ram_blocks:
+            del copystream
+        torch.cuda.empty_cache()
+        hidden_states = hidden_states.to(primary_device)
+    if self.project_out is not None:
+        hidden_states = self.project_out(hidden_states)
+    if breakmodel:
+        hidden_states = hidden_states.to(primary_device)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
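Note (not part of the commit): `new_forward_opt` assigns each decoder layer to a device the same way the existing hooks do: the first `ram_blocks` layers run on `primary_device` (their weights are streamed in from `extrastorage`), and the remaining layers are spread over the GPUs by bisecting the running total of `gpu_blocks`. A standalone sketch of that mapping, assuming integer device indices and the hypothetical helper name `layer_device`:

    import bisect
    import itertools

    def layer_device(i, gpu_blocks, n_layers, primary_device=0):
        # Layers not covered by gpu_blocks are "RAM blocks" computed on the primary device.
        ram_blocks = n_layers - sum(gpu_blocks)
        cumulative = tuple(itertools.accumulate(gpu_blocks))
        if i < ram_blocks:
            return primary_device
        # bisect_right picks the first GPU whose cumulative block count exceeds the offset.
        return bisect.bisect_right(cumulative, i - ram_blocks)

    # Example: 32 layers split as [8, 16] over two GPUs leaves 8 RAM blocks.
    assert layer_device(0, [8, 16], 32) == 0   # streamed RAM block -> primary device
    assert layer_device(10, [8, 16], 32) == 0  # offset 2 -> GPU 0
    assert layer_device(31, [8, 16], 32) == 1  # offset 23 -> GPU 1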
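Note (not part of the commit): for the RAM-block layers, the loop overlaps weight transfer with compute. While layer i runs, the weights for layer (i+1) % ram_blocks are copied from `extrastorage` into a rotating GPU buffer on a dedicated low-priority stream (`copystream`), and the `torch.cuda.synchronize()` after the layer guarantees the copy has landed. The commit does this with `torch.cuda.comm.broadcast` into the live parameter tensors; the sketch below shows the same overlap idea with a plain asynchronous `copy_` from pinned host memory, which is not what the commit literally calls:

    import torch

    if torch.cuda.is_available():
        # Low-priority side stream so the prefetch does not starve the compute stream.
        copystream = torch.cuda.Stream(device=0, priority=-1)
        gpu_buffer = torch.empty(1024, 1024, device="cuda:0")      # stand-in for a layer's parameter storage
        pinned_weights = torch.empty(1024, 1024, pin_memory=True)  # stand-in for extrastorage on the host
        with torch.cuda.stream(copystream):
            gpu_buffer.copy_(pinned_weights, non_blocking=True)    # overlaps with work on the default stream
        torch.cuda.synchronize()  # ensure the prefetched weights are in place before they are used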