mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
CPU fixes
This commit is contained in:
@@ -121,6 +121,8 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
return torch.float32
|
||||
elif utils.args.cpu:
|
||||
return torch.float32
|
||||
elif not self.usegpu and not self.breakmodel:
|
||||
return torch.float32
|
||||
return torch.float16
|
||||
|
||||
def _apply_warpers(
|
||||
@@ -268,9 +270,11 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
|
||||
else:
|
||||
gen_in = prompt_tokens
|
||||
|
||||
device = utils.get_auxilary_device()
|
||||
gen_in = gen_in.to(device)
|
||||
if not self.usegpu and not self.breakmodel:
|
||||
gen_in = gen_in.to("cpu")
|
||||
else:
|
||||
device = utils.get_auxilary_device()
|
||||
gen_in = gen_in.to(device)
|
||||
|
||||
additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
|
||||
|
||||
|
Reference in New Issue
Block a user