mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Update aux device to depend on primary device
This commit is contained in:
@@ -122,6 +122,9 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
|
||||
return ret
|
||||
|
||||
def get_auxilary_device(self) -> Union[str, int, torch.device]:
|
||||
return self.breakmodel_config.primary_device
|
||||
|
||||
def _get_target_dtype(self) -> Union[torch.float16, torch.float32]:
|
||||
if self.breakmodel_config.primary_device == "cpu":
|
||||
return torch.float32
|
||||
@@ -278,7 +281,7 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
if not self.usegpu and not self.breakmodel:
|
||||
gen_in = gen_in.to("cpu")
|
||||
else:
|
||||
device = utils.get_auxilary_device()
|
||||
device = self.get_auxilary_device()
|
||||
gen_in = gen_in.to(device)
|
||||
|
||||
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
|
||||
|
Reference in New Issue
Block a user