mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-01-22 05:10:35 +01:00
Breakmodel support for GPTJModel
This commit is contained in:
parent
f8bcc3411b
commit
25c9be5d02
@ -298,10 +298,12 @@ def device_config(model):
|
|||||||
model.transformer.ln_f.to(breakmodel.primary_device)
|
model.transformer.ln_f.to(breakmodel.primary_device)
|
||||||
if(hasattr(model, 'lm_head')):
|
if(hasattr(model, 'lm_head')):
|
||||||
model.lm_head.to(breakmodel.primary_device)
|
model.lm_head.to(breakmodel.primary_device)
|
||||||
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
if(hasattr(model.transformer, 'wpe')):
|
||||||
model.transformer.wpe.to(breakmodel.primary_device)
|
model.transformer.wpe.to(breakmodel.primary_device)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
GPTNeoModel.forward = breakmodel.new_forward
|
GPTNeoModel.forward = breakmodel.new_forward
|
||||||
|
if("GPTJModel" in globals()):
|
||||||
|
GPTJModel.forward = breakmodel.new_forward
|
||||||
generator = model.generate
|
generator = model.generate
|
||||||
breakmodel.move_hidden_layers(model.transformer)
|
breakmodel.move_hidden_layers(model.transformer)
|
||||||
|
|
||||||
|
@ -325,33 +325,27 @@ def new_forward(
|
|||||||
# Attention mask.
|
# Attention mask.
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
assert batch_size > 0, "batch_size has to be defined and > 0"
|
assert batch_size > 0, "batch_size has to be defined and > 0"
|
||||||
global_attention_mask = attention_mask.view(batch_size, -1)
|
attention_mask = attention_mask.view(batch_size, -1)
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
global_attention_mask = global_attention_mask[:, None, None, :]
|
attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|
||||||
# Since global_attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
# positions we want to attend and -10000.0 for masked positions.
|
# positions we want to attend and -10000.0 for masked positions.
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
# effectively the same as removing these entirely.
|
# effectively the same as removing these entirely.
|
||||||
global_attention_mask = global_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||||
global_attention_mask = (1.0 - global_attention_mask) * -10000.0
|
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||||
else:
|
|
||||||
global_attention_mask = None
|
|
||||||
|
|
||||||
# Local causal attention mask
|
|
||||||
batch_size, seq_length = input_shape
|
|
||||||
full_seq_length = seq_length + past_length
|
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
# attention_probs has shape bsz x num_heads x N x N
|
# attention_probs has shape bsz x num_heads x N x N
|
||||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
head_mask = self.get_head_mask(head_mask, getattr(self.config, "num_layers", None) or self.config.n_layer)
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
if breakmodel:
|
if breakmodel:
|
||||||
@ -367,7 +361,7 @@ def new_forward(
|
|||||||
inputs_embeds[:, pos:pos+emb.shape[1]] = emb
|
inputs_embeds[:, pos:pos+emb.shape[1]] = emb
|
||||||
offset += emb.shape[1]
|
offset += emb.shape[1]
|
||||||
|
|
||||||
if hasattr(self, 'rotary') and self.rotary:
|
if getattr(self, "wpe", None) is None:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
else:
|
else:
|
||||||
if breakmodel:
|
if breakmodel:
|
||||||
@ -403,9 +397,6 @@ def new_forward(
|
|||||||
with torch.cuda.stream(copystream):
|
with torch.cuda.stream(copystream):
|
||||||
torch.cuda.comm.broadcast(param2.data,out = [param1.data])
|
torch.cuda.comm.broadcast(param2.data,out = [param1.data])
|
||||||
|
|
||||||
attn_type = self.config.attention_layers[i]
|
|
||||||
attn_mask = global_attention_mask
|
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states.cpu(),)
|
all_hidden_states = all_hidden_states + (hidden_states.cpu(),)
|
||||||
|
|
||||||
@ -413,8 +404,7 @@ def new_forward(
|
|||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
"`use_cache=False`..."
|
|
||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|
||||||
@ -429,7 +419,7 @@ def new_forward(
|
|||||||
create_custom_forward(block),
|
create_custom_forward(block),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
None,
|
None,
|
||||||
attn_mask,
|
attention_mask,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -438,7 +428,7 @@ def new_forward(
|
|||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
|
hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
|
||||||
layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None and i >= ram_blocks and len(layer_past) and layer_past[0].device.index != device else layer_past,
|
layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None and i >= ram_blocks and len(layer_past) and layer_past[0].device.index != device else layer_past,
|
||||||
attention_mask=attn_mask.to(device) if breakmodel and attn_mask is not None else attn_mask,
|
attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
|
||||||
head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i],
|
head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i],
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
Loading…
Reference in New Issue
Block a user