Probably fix f32

This commit is contained in:
somebody
2023-06-21 16:53:30 -05:00
parent 70f113141c
commit c40649a74e

View File

@@ -107,6 +107,13 @@ class HFTorchInferenceModel(HFInferenceModel):
return ret
def _get_target_dtype(self) -> torch.dtype:
    """Choose the torch dtype to load model weights in.

    Returns ``torch.float32`` whenever inference will run on the CPU
    (half precision is poorly supported there), and ``torch.float16``
    otherwise.

    NOTE(review): the original annotation was
    ``Union[torch.float16, torch.float32]`` — those are dtype *values*,
    not types, which ``typing.Union`` rejects; ``torch.dtype`` is the
    correct type of both returns.
    """
    # Breakmodel's primary device being the CPU means no GPU is involved.
    if self.breakmodel_config.primary_device == "cpu":
        return torch.float32
    # An explicit --cpu command-line flag also forces CPU inference.
    elif utils.args.cpu:
        return torch.float32
    return torch.float16
def _apply_warpers(
self, scores: torch.Tensor, input_ids: torch.Tensor
) -> torch.Tensor:
@@ -317,7 +324,7 @@ class HFTorchInferenceModel(HFInferenceModel):
model = AutoModelForCausalLM.from_pretrained(
location,
offload_folder="accelerate-disk-cache",
torch_dtype=torch.float16,
torch_dtype=self._get_target_dtype(),
**tf_kwargs,
)