From afa8766ea68d8b08218edfab36526b75562609a6 Mon Sep 17 00:00:00 2001 From: onesome Date: Fri, 14 Jul 2023 18:01:18 -0500 Subject: [PATCH] Add is_valid --- modeling/inference_models/basic_hf/class.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/modeling/inference_models/basic_hf/class.py b/modeling/inference_models/basic_hf/class.py index 74bfcd17..afca13ee 100644 --- a/modeling/inference_models/basic_hf/class.py +++ b/modeling/inference_models/basic_hf/class.py @@ -34,6 +34,24 @@ class model_backend(InferenceModel): self.model_name = "Basic Huggingface" self.path = None + def is_valid(self, model_name, model_path, menu_path): + try: + if model_path is not None and os.path.exists(model_path): + self.model_config = AutoConfig.from_pretrained(model_path) + elif os.path.exists("models/{}".format(model_name.replace("/", "_"))): + self.model_config = AutoConfig.from_pretrained( + "models/{}".format(model_name.replace("/", "_")), + revision=utils.koboldai_vars.revision, + cache_dir="cache", + ) + else: + self.model_config = AutoConfig.from_pretrained( + model_name, revision=utils.koboldai_vars.revision, cache_dir="cache" + ) + return True + except: + return False + def get_requested_parameters( self, model_name: str, model_path: str, menu_path: str, parameters: dict = {} ):