Model loading fix

This commit is contained in:
Henk
2023-04-02 00:29:56 +02:00
parent 943d0fe68a
commit 4a8b099888

View File

@@ -1123,7 +1123,7 @@ def move_model_to_devices(model):
for key, value in model.state_dict().items():
target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16
if(value.dtype is not target_dtype):
accelerate.utils.set_module_tensor_to_device(model, key, target_dtype)
accelerate.utils.set_module_tensor_to_device(model, key, torch.device(value.device), value, target_dtype)
disk_blocks = breakmodel.disk_blocks
gpu_blocks = breakmodel.gpu_blocks
ram_blocks = len(utils.layers_module_names) - sum(gpu_blocks)