mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
CPU fixes
This commit is contained in:
@@ -121,6 +121,8 @@ class HFTorchInferenceModel(HFInferenceModel):
|
|||||||
return torch.float32
|
return torch.float32
|
||||||
elif utils.args.cpu:
|
elif utils.args.cpu:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
elif not self.usegpu and not self.breakmodel:
|
||||||
|
return torch.float32
|
||||||
return torch.float16
|
return torch.float16
|
||||||
|
|
||||||
def _apply_warpers(
|
def _apply_warpers(
|
||||||
@@ -268,7 +270,9 @@ class HFTorchInferenceModel(HFInferenceModel):
|
|||||||
gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
|
gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
|
||||||
else:
|
else:
|
||||||
gen_in = prompt_tokens
|
gen_in = prompt_tokens
|
||||||
|
if not self.usegpu and not self.breakmodel:
|
||||||
|
gen_in = gen_in.to("cpu")
|
||||||
|
else:
|
||||||
device = utils.get_auxilary_device()
|
device = utils.get_auxilary_device()
|
||||||
gen_in = gen_in.to(device)
|
gen_in = gen_in.to(device)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user