diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py index 9e7710fc..108887f6 100644 --- a/modeling/inference_models/hf_torch.py +++ b/modeling/inference_models/hf_torch.py @@ -57,6 +57,10 @@ class BreakmodelConfig: return "cpu" elif torch.cuda.device_count() <= 0: return "cpu" + + for device_index, blocks in enumerate(self.gpu_blocks): + if blocks: + return device_index return 0 def get_device_map(self, model: nn.Module) -> dict: