OPT breakmodel bug fix
This commit is contained in:
parent
1200173386
commit
a051bf4397
|
@@ -816,7 +816,7 @@ def new_forward_opt(
|
|||
if breakmodel:
|
||||
device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
|
||||
layer_outputs = decoder_layer(
|
||||
- hidden_states,
|
||||
+ hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
|
||||
attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
|
||||
layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None),
|
||||
past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,
|
||||
|
|
Loading…
Reference in New Issue