diff --git a/breakmodel.py b/breakmodel.py index 262d39bd..eb49e669 100644 --- a/breakmodel.py +++ b/breakmodel.py @@ -816,7 +816,7 @@ def new_forward_opt( if breakmodel: device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks) layer_outputs = decoder_layer( - hidden_states, + hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states, attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask, layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None), past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,