mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Probably fix f32
This commit is contained in:
@@ -107,6 +107,13 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
|
||||
return ret
|
||||
|
||||
def _get_target_dtype(self) -> Union[torch.float16, torch.float32]:
|
||||
if self.breakmodel_config.primary_device == "cpu":
|
||||
return torch.float32
|
||||
elif utils.args.cpu:
|
||||
return torch.float32
|
||||
return torch.float16
|
||||
|
||||
def _apply_warpers(
|
||||
self, scores: torch.Tensor, input_ids: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
@@ -317,7 +324,7 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
location,
|
||||
offload_folder="accelerate-disk-cache",
|
||||
torch_dtype=torch.float16,
|
||||
torch_dtype=self._get_target_dtype(),
|
||||
**tf_kwargs,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user