diff --git a/modeling/inference_models/generic_hf_torch/class.py b/modeling/inference_models/generic_hf_torch/class.py
index a0b7b4cb..b56a7c45 100644
--- a/modeling/inference_models/generic_hf_torch/class.py
+++ b/modeling/inference_models/generic_hf_torch/class.py
@@ -9,6 +9,8 @@ from typing import Union
 from transformers import GPTNeoForCausalLM, GPT2LMHeadModel
 from hf_bleeding_edge import AutoModelForCausalLM
 
+from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
+
 import utils
 import modeling.lazy_loader as lazy_loader
 import koboldai_settings
@@ -27,6 +29,28 @@ model_backend_name = "Huggingface"
 
 
 class model_backend(HFTorchInferenceModel):
+    def is_valid(self, model_name, model_path, menu_path):
+        base_is_valid = super().is_valid(model_name, model_path, menu_path)
+
+        # Resolve the local directory for this model, if it has been downloaded.
+        path = None
+        gen_path = "models/{}".format(model_name.replace('/', '_'))
+        if model_path is not None and os.path.exists(model_path):
+            path = model_path
+        elif os.path.exists(gen_path):
+            path = gen_path
+
+        # No local copy yet: fall back to the base check instead of
+        # crashing in os.path.join() below.
+        if path is None:
+            return base_is_valid
+
+        # A valid model directory must contain at least one recognized
+        # checkpoint file (PyTorch, TF, Flax, or safetensors).
+        fnames = [WEIGHTS_NAME, WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]
+
+        return base_is_valid and any(os.path.exists(os.path.join(path, fname)) for fname in fnames)
+
     def _initialize_model(self):
         return
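
Note: the names imported from transformers.utils are the canonical checkpoint
filenames that transformers itself writes, e.g. WEIGHTS_NAME == "pytorch_model.bin"
and SAFE_WEIGHTS_NAME == "model.safetensors". A minimal sketch of the check the
new is_valid method performs (the "models/EleutherAI_gpt-neo-1.3B" directory is a
hypothetical example, not taken from this patch):

    import os
    from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME

    # Hypothetical local model directory, following the backend's
    # "models/{model_name with '/' replaced by '_'}" convention.
    path = "models/EleutherAI_gpt-neo-1.3B"

    # The directory counts as loadable when any recognized checkpoint
    # file is present.
    for fname in (WEIGHTS_NAME, SAFE_WEIGHTS_NAME):
        print(fname, os.path.exists(os.path.join(path, fname)))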