Update aux device to depend on primary device

This commit is contained in:
somebody
2023-07-03 19:36:31 -05:00
parent 6f7e6422ef
commit bce1a907e5
4 changed files with 16 additions and 14 deletions

View File

@@ -122,6 +122,9 @@ class HFTorchInferenceModel(HFInferenceModel):
return ret
def get_auxilary_device(self) -> Union[str, int, torch.device]:
return self.breakmodel_config.primary_device
def _get_target_dtype(self) -> Union[torch.float16, torch.float32]:
if self.breakmodel_config.primary_device == "cpu":
return torch.float32
@@ -278,7 +281,7 @@ class HFTorchInferenceModel(HFInferenceModel):
if not self.usegpu and not self.breakmodel:
gen_in = gen_in.to("cpu")
else:
device = utils.get_auxilary_device()
device = self.get_auxilary_device()
gen_in = gen_in.to(device)
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []