Probably fix f32: select float32 as the target dtype when running on CPU

This commit is contained in:
somebody
2023-06-21 16:53:30 -05:00
parent 70f113141c
commit c40649a74e

View File

@@ -107,6 +107,13 @@ class HFTorchInferenceModel(HFInferenceModel):
return ret return ret
def _get_target_dtype(self) -> torch.dtype:
    """Choose the torch dtype to load model weights in.

    Returns:
        ``torch.float32`` when inference will run on CPU — either the
        breakmodel primary device is ``"cpu"`` or the ``--cpu`` command-line
        flag was passed (presumably because fp16 kernels are poorly
        supported on CPU — TODO confirm that rationale); ``torch.float16``
        otherwise.
    """
    # NOTE: the previous annotation `Union[torch.float16, torch.float32]`
    # used dtype *instances* where typing.Union requires types; the values
    # returned here are torch.dtype instances, so `torch.dtype` is correct.
    if self.breakmodel_config.primary_device == "cpu":
        return torch.float32
    if utils.args.cpu:
        return torch.float32
    return torch.float16
def _apply_warpers( def _apply_warpers(
self, scores: torch.Tensor, input_ids: torch.Tensor self, scores: torch.Tensor, input_ids: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
@@ -317,7 +324,7 @@ class HFTorchInferenceModel(HFInferenceModel):
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
location, location,
offload_folder="accelerate-disk-cache", offload_folder="accelerate-disk-cache",
torch_dtype=torch.float16, torch_dtype=self._get_target_dtype(),
**tf_kwargs, **tf_kwargs,
) )