diff --git a/modeling/inference_models/basic_hf/class.py b/modeling/inference_models/basic_hf/class.py index 74bfcd17..afca13ee 100644 --- a/modeling/inference_models/basic_hf/class.py +++ b/modeling/inference_models/basic_hf/class.py @@ -34,6 +34,24 @@ class model_backend(InferenceModel): self.model_name = "Basic Huggingface" self.path = None + def is_valid(self, model_name, model_path, menu_path): + try: + if model_path is not None and os.path.exists(model_path): + self.model_config = AutoConfig.from_pretrained(model_path) + elif os.path.exists("models/{}".format(model_name.replace("/", "_"))): + self.model_config = AutoConfig.from_pretrained( + "models/{}".format(model_name.replace("/", "_")), + revision=utils.koboldai_vars.revision, + cache_dir="cache", + ) + else: + self.model_config = AutoConfig.from_pretrained( + model_name, revision=utils.koboldai_vars.revision, cache_dir="cache" + ) + return True + except: + return False + def get_requested_parameters( self, model_name: str, model_path: str, menu_path: str, parameters: dict = {} ):