diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py index adced4a4..8a3e750f 100644 --- a/modeling/inference_models/hf_torch.py +++ b/modeling/inference_models/hf_torch.py @@ -315,7 +315,11 @@ class HFTorchInferenceModel(HFInferenceModel): ) if value.dtype is not target_dtype: accelerate.utils.set_module_tensor_to_device( - self.model, key, target_dtype + self.model, + tensor_name=key, + device=torch.device(value.device), + value=value, + dtype=target_dtype, ) disk_blocks = breakmodel.disk_blocks