mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Don't show HF support if no HF model files are found
This commit is contained in:
@@ -9,6 +9,8 @@ from typing import Union
|
|||||||
from transformers import GPTNeoForCausalLM, GPT2LMHeadModel
|
from transformers import GPTNeoForCausalLM, GPT2LMHeadModel
|
||||||
from hf_bleeding_edge import AutoModelForCausalLM
|
from hf_bleeding_edge import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
import modeling.lazy_loader as lazy_loader
|
import modeling.lazy_loader as lazy_loader
|
||||||
import koboldai_settings
|
import koboldai_settings
|
||||||
@@ -27,6 +29,19 @@ model_backend_name = "Huggingface"
|
|||||||
|
|
||||||
class model_backend(HFTorchInferenceModel):
|
class model_backend(HFTorchInferenceModel):
|
||||||
|
|
||||||
|
def is_valid(self, model_name, model_path, menu_path):
|
||||||
|
base_is_valid = super().is_valid(model_name, model_path, menu_path)
|
||||||
|
path = False
|
||||||
|
gen_path = "models/{}".format(model_name.replace('/', '_'))
|
||||||
|
if model_path is not None and os.path.exists(model_path):
|
||||||
|
path = model_path
|
||||||
|
elif os.path.exists(gen_path):
|
||||||
|
path = gen_path
|
||||||
|
|
||||||
|
fnames = [WEIGHTS_NAME, WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]
|
||||||
|
|
||||||
|
return base_is_valid and any(os.path.exists(os.path.join(path, fname)) for fname in fnames)
|
||||||
|
|
||||||
def _initialize_model(self):
|
def _initialize_model(self):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user