Actually get correct primary device

This commit is contained in:
somebody
2023-07-03 19:04:48 -05:00
parent 59c731f805
commit 6f7e6422ef

View File

@@ -57,6 +57,10 @@ class BreakmodelConfig:
return "cpu"
elif torch.cuda.device_count() <= 0:
return "cpu"
for device_index, blocks in enumerate(self.gpu_blocks):
if blocks:
return device_index
return 0
def get_device_map(self, model: nn.Module) -> dict: