mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Don't show HF support if no HF model files are found
This commit is contained in:
@@ -9,6 +9,8 @@ from typing import Union
|
||||
from transformers import GPTNeoForCausalLM, GPT2LMHeadModel
|
||||
from hf_bleeding_edge import AutoModelForCausalLM
|
||||
|
||||
from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
import utils
|
||||
import modeling.lazy_loader as lazy_loader
|
||||
import koboldai_settings
|
||||
@@ -27,6 +29,19 @@ model_backend_name = "Huggingface"
|
||||
|
||||
class model_backend(HFTorchInferenceModel):
|
||||
|
||||
def is_valid(self, model_name, model_path, menu_path):
    """Return True only if the base backend checks pass AND a local
    directory containing recognized Huggingface model weight files exists.

    Parameters mirror the base class: *model_name* is the HF model id,
    *model_path* an optional explicit local directory, *menu_path* the
    menu location (forwarded to the base check only).

    Bug fix: the original initialized ``path = False`` and, when neither
    candidate directory existed, fell through to
    ``os.path.join(False, fname)`` which raises TypeError. We now return
    False early when no local model directory is found.
    """
    base_is_valid = super().is_valid(model_name, model_path, menu_path)
    # Preserve the original short-circuit: if the base check fails,
    # never touch the filesystem for weight files.
    if not base_is_valid:
        return False

    # Prefer the explicit model_path; fall back to the conventional
    # "models/<name with '/' replaced>" directory.
    path = None
    gen_path = "models/{}".format(model_name.replace('/', '_'))
    if model_path is not None and os.path.exists(model_path):
        path = model_path
    elif os.path.exists(gen_path):
        path = gen_path

    # No local directory at all -> there can be no HF weight files.
    if path is None:
        return False

    # Any known PyTorch/TF/Flax/safetensors weight filename (or its
    # sharded-index variant) counts as evidence of HF support.
    fnames = [WEIGHTS_NAME, WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]
    return any(os.path.exists(os.path.join(path, fname)) for fname in fnames)
|
||||
|
||||
def _initialize_model(self):
    """Intentional no-op.

    This backend performs no extra initialization here; loading is
    handled elsewhere in the inference pipeline. Always returns None.
    """
    return None
|
||||
|
||||
|
Reference in New Issue
Block a user