diff --git a/modeling/inference_models/hf_mtj/class.py b/modeling/inference_models/hf_mtj/class.py index 1b6b2cb8..c0f70843 100644 --- a/modeling/inference_models/hf_mtj/class.py +++ b/modeling/inference_models/hf_mtj/class.py @@ -29,15 +29,21 @@ class model_backend(HFInferenceModel): #model_name: str, ) -> None: super().__init__() - self.hf_torch = False - self.model_config = None - self.capabilties = ModelCapabilities( - embedding_manipulation=False, - post_token_hooks=False, - stopper_hooks=False, - post_token_probs=False, - uses_tpu=True, - ) + import importlib + dependency_exists = importlib.util.find_spec("jax") + if dependency_exists: + self.hf_torch = False + self.model_config = None + self.capabilties = ModelCapabilities( + embedding_manipulation=False, + post_token_hooks=False, + stopper_hooks=False, + post_token_probs=False, + uses_tpu=True, + ) + else: + logger.debug("Jax is not installed, hiding TPU backend") + self.disable = True def is_valid(self, model_name, model_path, menu_path): # This file shouldn't be imported unless using the TPU