From bbecdaeedb4ef66ecc9e2dd3fc7e6824fe36b8a9 Mon Sep 17 00:00:00 2001 From: Henk Date: Wed, 21 Jun 2023 17:08:45 +0200 Subject: [PATCH] Silently disable MTJ when Jax is not installed --- modeling/inference_models/hf_mtj/class.py | 24 ++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/modeling/inference_models/hf_mtj/class.py b/modeling/inference_models/hf_mtj/class.py index 1b6b2cb8..c0f70843 100644 --- a/modeling/inference_models/hf_mtj/class.py +++ b/modeling/inference_models/hf_mtj/class.py @@ -29,15 +29,21 @@ class model_backend(HFInferenceModel): #model_name: str, ) -> None: super().__init__() - self.hf_torch = False - self.model_config = None - self.capabilties = ModelCapabilities( - embedding_manipulation=False, - post_token_hooks=False, - stopper_hooks=False, - post_token_probs=False, - uses_tpu=True, - ) + import importlib + dependency_exists = importlib.util.find_spec("jax") + if dependency_exists: + self.hf_torch = False + self.model_config = None + self.capabilties = ModelCapabilities( + embedding_manipulation=False, + post_token_hooks=False, + stopper_hooks=False, + post_token_probs=False, + uses_tpu=True, + ) + else: + logger.debug("Jax is not installed, hiding TPU backend") + self.disable = True def is_valid(self, model_name, model_path, menu_path): # This file shouldn't be imported unless using the TPU