OPT breakmodel bug fix

This commit is contained in:
Gnome Ann 2022-05-13 10:45:57 -04:00
parent 1200173386
commit a051bf4397
1 changed files with 1 additions and 1 deletions

View File

@ -816,7 +816,7 @@ def new_forward_opt(
if breakmodel: if breakmodel:
device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks) device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask, attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None), layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None),
past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value, past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,