Model fix

This commit is contained in:
somebody
2023-04-02 15:47:52 -05:00
parent 9d70646e4d
commit 77f0797b1a

View File

@@ -315,7 +315,11 @@ class HFTorchInferenceModel(HFInferenceModel):
)
if value.dtype is not target_dtype:
accelerate.utils.set_module_tensor_to_device(
self.model, key, target_dtype
self.model,
tensor_name=key,
device=torch.device(value.device),
value=value,
dtype=target_dtype,
)
disk_blocks = breakmodel.disk_blocks