Silently disable MTJ when Jax is not installed

This commit is contained in:
Henk
2023-06-21 17:08:45 +02:00
parent d46663ac0d
commit bbecdaeedb

View File

@@ -29,15 +29,21 @@ class model_backend(HFInferenceModel):
#model_name: str,
) -> None:
super().__init__()
self.hf_torch = False
self.model_config = None
self.capabilties = ModelCapabilities(
embedding_manipulation=False,
post_token_hooks=False,
stopper_hooks=False,
post_token_probs=False,
uses_tpu=True,
)
import importlib
dependency_exists = importlib.util.find_spec("jax")
if dependency_exists:
self.hf_torch = False
self.model_config = None
self.capabilties = ModelCapabilities(
embedding_manipulation=False,
post_token_hooks=False,
stopper_hooks=False,
post_token_probs=False,
uses_tpu=True,
)
else:
logger.debug("Jax is not installed, hiding TPU backend")
self.disable = True
def is_valid(self, model_name, model_path, menu_path):
# This file shouldn't be imported unless using the TPU