Fix for CPU loading

somebody
2023-06-21 21:18:43 -05:00
parent b81f61b820
commit 5ee20bd7d6


@@ -50,8 +50,12 @@ class BreakmodelConfig:
         self.primary_device = 0 if torch.cuda.device_count() > 0 else "cpu"

     def get_device_map(self, model: nn.Module) -> dict:
         # HACK
-        if utils.args.cpu:
+        if (
+            # Explicitly CPU-only
+            utils.args.cpu
+            # No blocks are on GPU
+            or not sum(self.gpu_blocks)
+        ):
             self.primary_device = "cpu"
         ram_blocks = len(utils.layers_module_names) - sum(self.gpu_blocks)
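
For readers skimming the diff: the new `or not sum(self.gpu_blocks)` clause makes the CPU fallback trigger not only when the explicit CPU flag is set, but also when every entry in gpu_blocks is zero, i.e. no layers were assigned to any GPU. Below is a minimal, standalone sketch of that condition; pick_primary_device and its parameters are invented for illustration and are not part of the repository.

# Hypothetical, self-contained illustration of the fixed check.
def pick_primary_device(gpu_blocks, cpu_flag):
    """Return "cpu" when the model should load entirely on CPU.

    gpu_blocks: per-GPU layer counts, e.g. [0, 0] means no layers on any GPU.
    cpu_flag:   stands in for the explicit CPU-only command-line switch.
    """
    # Either the user forced CPU, or sum(gpu_blocks) == 0 because every
    # block count is zero, so no layer would land on a GPU anyway.
    if cpu_flag or not sum(gpu_blocks):
        return "cpu"
    return 0  # index of the first CUDA device, matching the original default

assert pick_primary_device([0, 0], cpu_flag=False) == "cpu"  # the case this commit fixes
assert pick_primary_device([12, 8], cpu_flag=False) == 0
assert pick_primary_device([12, 8], cpu_flag=True) == "cpu"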