This commit is contained in:
Henk
2023-07-24 02:05:07 +02:00
parent 9fc9cb92f7
commit 30495cf8d8
2 changed files with 9 additions and 3 deletions

View File

@@ -126,8 +126,13 @@ class HFTorchInferenceModel(HFInferenceModel):
return ret
def get_auxilary_device(self) -> Union[str, int, torch.device]:
return self.breakmodel_config.primary_device
if self.breakmodel:
return self.breakmodel_config.primary_device
if self.usegpu:
return "cuda:0"
else:
return "cpu"
def _get_target_dtype(self) -> Union[torch.float16, torch.float32]:
if self.breakmodel_config.primary_device == "cpu":
return torch.float32