From 6f7e6422ef5fd038a561ab28b9c49f8d6289001f Mon Sep 17 00:00:00 2001 From: somebody Date: Mon, 3 Jul 2023 19:04:48 -0500 Subject: [PATCH] Actually get correct primary device --- modeling/inference_models/hf_torch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py index 9e7710fc..108887f6 100644 --- a/modeling/inference_models/hf_torch.py +++ b/modeling/inference_models/hf_torch.py @@ -57,6 +57,10 @@ class BreakmodelConfig: return "cpu" elif torch.cuda.device_count() <= 0: return "cpu" + + for device_index, blocks in enumerate(self.gpu_blocks): + if blocks: + return device_index return 0 def get_device_map(self, model: nn.Module) -> dict: