CPU fixes

This commit is contained in:
Henk
2023-07-02 21:50:23 +02:00
parent 1da4580e8b
commit 81e72329af

View File

@@ -121,6 +121,8 @@ class HFTorchInferenceModel(HFInferenceModel):
return torch.float32
elif utils.args.cpu:
return torch.float32
elif not self.usegpu and not self.breakmodel:
return torch.float32
return torch.float16
def _apply_warpers(
@@ -268,9 +270,11 @@ class HFTorchInferenceModel(HFInferenceModel):
gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
else:
gen_in = prompt_tokens
device = utils.get_auxilary_device()
gen_in = gen_in.to(device)
if not self.usegpu and not self.breakmodel:
gen_in = gen_in.to("cpu")
else:
device = utils.get_auxilary_device()
gen_in = gen_in.to(device)
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []