From 4a8b099888e8a00bc7b66aea48ed4053166ceaf6 Mon Sep 17 00:00:00 2001 From: Henk Date: Sun, 2 Apr 2023 00:29:56 +0200 Subject: [PATCH] Model loading fix --- aiserver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index 3368447b..7cbbc8ac 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1123,7 +1123,7 @@ def move_model_to_devices(model): for key, value in model.state_dict().items(): target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16 if(value.dtype is not target_dtype): - accelerate.utils.set_module_tensor_to_device(model, key, target_dtype) + accelerate.utils.set_module_tensor_to_device(model, key, torch.device(value.device), value, target_dtype) disk_blocks = breakmodel.disk_blocks gpu_blocks = breakmodel.gpu_blocks ram_blocks = len(utils.layers_module_names) - sum(gpu_blocks)