From 5ee20bd7d635bc54a748e460f00d6caf2181dd29 Mon Sep 17 00:00:00 2001 From: somebody Date: Wed, 21 Jun 2023 21:18:43 -0500 Subject: [PATCH] Fix for CPU loading --- modeling/inference_models/hf_torch.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py index 6bcd88cd..514a1e5b 100644 --- a/modeling/inference_models/hf_torch.py +++ b/modeling/inference_models/hf_torch.py @@ -50,8 +50,12 @@ class BreakmodelConfig: self.primary_device = 0 if torch.cuda.device_count() > 0 else "cpu" def get_device_map(self, model: nn.Module) -> dict: - # HACK - if utils.args.cpu: + if ( + # Explicitly CPU-only + utils.args.cpu + # No blocks are on GPU + or not sum(self.gpu_blocks) + ): self.primary_device = "cpu" ram_blocks = len(utils.layers_module_names) - sum(self.gpu_blocks)