diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 6bcd88cd..514a1e5b 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -50,8 +50,12 @@ class BreakmodelConfig:
         self.primary_device = 0 if torch.cuda.device_count() > 0 else "cpu"
 
     def get_device_map(self, model: nn.Module) -> dict:
-        # HACK
-        if utils.args.cpu:
+        if (
+            # Explicitly CPU-only
+            utils.args.cpu
+            # No blocks are on GPU
+            or not sum(self.gpu_blocks)
+        ):
             self.primary_device = "cpu"
 
         ram_blocks = len(utils.layers_module_names) - sum(self.gpu_blocks)
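
For context, a minimal standalone sketch of the new fallback condition. The pick_primary_device helper and its arguments are hypothetical stand-ins for utils.args.cpu, BreakmodelConfig.gpu_blocks, and primary_device; they are not part of the actual change and only illustrate the intended behavior: fall back to CPU either when the user explicitly requests CPU-only mode or when no layer blocks are assigned to any GPU.

    # Sketch only, assuming the semantics of the diff above; names are stand-ins.
    from typing import List, Union


    def pick_primary_device(
        cpu_flag: bool,                  # stand-in for utils.args.cpu
        gpu_blocks: List[int],           # stand-in for BreakmodelConfig.gpu_blocks
        current_primary: Union[int, str],
    ) -> Union[int, str]:
        """Return "cpu" when CPU is forced or no blocks are placed on any GPU."""
        if (
            # Explicitly CPU-only
            cpu_flag
            # No blocks are on GPU
            or not sum(gpu_blocks)
        ):
            return "cpu"
        return current_primary


    assert pick_primary_device(False, [0, 0], 0) == "cpu"   # no GPU blocks -> CPU
    assert pick_primary_device(True, [4, 4], 0) == "cpu"    # --cpu forces CPU
    assert pick_primary_device(False, [4, 4], 0) == 0       # GPU blocks present -> keep device 0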