mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Silently disable MTJ when Jax is not installed
This commit is contained in:
@@ -29,15 +29,21 @@ class model_backend(HFInferenceModel):
|
||||
#model_name: str,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hf_torch = False
|
||||
self.model_config = None
|
||||
self.capabilties = ModelCapabilities(
|
||||
embedding_manipulation=False,
|
||||
post_token_hooks=False,
|
||||
stopper_hooks=False,
|
||||
post_token_probs=False,
|
||||
uses_tpu=True,
|
||||
)
|
||||
import importlib
|
||||
dependency_exists = importlib.util.find_spec("jax")
|
||||
if dependency_exists:
|
||||
self.hf_torch = False
|
||||
self.model_config = None
|
||||
self.capabilties = ModelCapabilities(
|
||||
embedding_manipulation=False,
|
||||
post_token_hooks=False,
|
||||
stopper_hooks=False,
|
||||
post_token_probs=False,
|
||||
uses_tpu=True,
|
||||
)
|
||||
else:
|
||||
logger.debug("Jax is not installed, hiding TPU backend")
|
||||
self.disable = True
|
||||
|
||||
def is_valid(self, model_name, model_path, menu_path):
|
||||
# This file shouldn't be imported unless using the TPU
|
||||
|
Reference in New Issue
Block a user