From c40649a74e18651d54795c460ccd06dd4acb92f5 Mon Sep 17 00:00:00 2001
From: somebody
Date: Wed, 21 Jun 2023 16:53:30 -0500
Subject: [PATCH] Probably fix f32

---
 modeling/inference_models/hf_torch.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 9a941cf6..8d06ff6e 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -107,6 +107,16 @@ class HFTorchInferenceModel(HFInferenceModel):

         return ret

+    def _get_target_dtype(self) -> torch.dtype:
+        """Return the dtype passed as ``torch_dtype`` when loading the model.
+
+        ``torch.float32`` when the primary breakmodel device is ``"cpu"`` or
+        ``utils.args.cpu`` is set; ``torch.float16`` otherwise.
+        """
+        if self.breakmodel_config.primary_device == "cpu" or utils.args.cpu:
+            return torch.float32
+        return torch.float16
+
     def _apply_warpers(
         self, scores: torch.Tensor, input_ids: torch.Tensor
     ) -> torch.Tensor:
@@ -317,7 +327,7 @@ class HFTorchInferenceModel(HFInferenceModel):
             model = AutoModelForCausalLM.from_pretrained(
                 location,
                 offload_folder="accelerate-disk-cache",
-                torch_dtype=torch.float16,
+                torch_dtype=self._get_target_dtype(),
                 **tf_kwargs,
             )
