mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Probably fix f32
This commit is contained in:
@@ -107,6 +107,13 @@ class HFTorchInferenceModel(HFInferenceModel):
|
|||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def _get_target_dtype(self) -> torch.dtype:
    """Return the torch dtype the model weights should be loaded with.

    Returns:
        ``torch.float32`` when the breakmodel primary device is ``"cpu"``
        or when ``utils.args.cpu`` is truthy; ``torch.float16`` otherwise.
    """
    # NOTE(review): presumably full precision is forced on CPU because
    # half-precision ops are poorly supported there -- confirm.
    # The primary-device check stays first so ``utils.args`` is only
    # consulted when the breakmodel config does not already select CPU.
    if self.breakmodel_config.primary_device == "cpu" or utils.args.cpu:
        return torch.float32
    return torch.float16
|
||||||
|
|
||||||
def _apply_warpers(
|
def _apply_warpers(
|
||||||
self, scores: torch.Tensor, input_ids: torch.Tensor
|
self, scores: torch.Tensor, input_ids: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -317,7 +324,7 @@ class HFTorchInferenceModel(HFInferenceModel):
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
location,
|
location,
|
||||||
offload_folder="accelerate-disk-cache",
|
offload_folder="accelerate-disk-cache",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=self._get_target_dtype(),
|
||||||
**tf_kwargs,
|
**tf_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user