diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py index 7719f022..2c035e62 100644 --- a/modeling/inference_models/hf_torch.py +++ b/modeling/inference_models/hf_torch.py @@ -121,6 +121,8 @@ class HFTorchInferenceModel(HFInferenceModel): return torch.float32 elif utils.args.cpu: return torch.float32 + elif not self.usegpu and not self.breakmodel: + return torch.float32 return torch.float16 def _apply_warpers( @@ -268,9 +270,11 @@ class HFTorchInferenceModel(HFInferenceModel): gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None] else: gen_in = prompt_tokens - - device = utils.get_auxilary_device() - gen_in = gen_in.to(device) + if not self.usegpu and not self.breakmodel: + gen_in = gen_in.to("cpu") + else: + device = utils.get_auxilary_device() + gen_in = gen_in.to(device) additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []