diff --git a/breakmodel.py b/breakmodel.py index 73e40222..5724d4e2 100644 --- a/breakmodel.py +++ b/breakmodel.py @@ -387,7 +387,7 @@ def new_forward( all_hidden_states = () if output_hidden_states else None if breakmodel and ram_blocks: - copystream = torch.cuda.Stream(device=0,priority = -1) + copystream = torch.cuda.Stream(device=primary_device, priority=-1) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):