Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-02-02 10:36:47 +01:00)
Fix the model selection GUI when there is no internet connection

commit 55f45c4912 (parent 0d4bffe8f8)
aiserver.py | 20
@@ -1474,22 +1474,22 @@ def get_model_info(model, directory=""):
 def get_layer_count(model, directory=""):
     if(model not in ["InferKit", "Colab", "API", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
-        if(vars.model == "GPT2Custom"):
-            model_config = open(vars.custmodpth + "/config.json", "r")
+        if(model == "GPT2Custom"):
+            with open(os.path.join(directory, "config.json"), "r") as f:
+                model_config = json.load(f)
         # Get the model_type from the config or assume a model type if it isn't present
         else:
+            if(directory):
+                model = directory
             from transformers import AutoConfig
-            if directory == "":
-                model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache")
+            if(os.path.isdir(model.replace('/', '_'))):
+                model_config = AutoConfig.from_pretrained(model.replace('/', '_'), revision=vars.revision, cache_dir="cache")
+            elif(os.path.isdir("models/{}".format(model.replace('/', '_')))):
+                model_config = AutoConfig.from_pretrained("models/{}".format(model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
             elif(os.path.isdir(directory)):
                 model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache")
-            elif(os.path.isdir(vars.custmodpth.replace('/', '_'))):
-                model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache")
             else:
-                model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
+                model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache")
         return utils.num_layers(model_config)
     else:
         return None
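Condensed, the new lookup order in get_layer_count is: a local directory named after the model (with "/" mapped to "_"), then models/<name>, then the explicitly supplied directory, and only then the bare model name, which is the one path that may reach out to the Hugging Face Hub. Below is a minimal sketch of that order; the helper name load_model_config and the plain revision parameter are illustrative stand-ins for the actual code, which uses vars.revision and feeds the result to utils.num_layers.

import json
import os

from transformers import AutoConfig

# Illustrative helper (not part of KoboldAI); mirrors the branch order of the
# patched get_layer_count, with a plain `revision` argument instead of vars.revision.
def load_model_config(model, directory="", revision=None):
    if model == "GPT2Custom":
        # Custom GPT-2 models carry a plain config.json next to the weights,
        # so this path returns a dict rather than an AutoConfig object.
        with open(os.path.join(directory, "config.json"), "r") as f:
            return json.load(f)
    if directory:
        model = directory
    # Prefer local copies so no internet connection is required.
    if os.path.isdir(model.replace("/", "_")):
        path = model.replace("/", "_")
    elif os.path.isdir("models/{}".format(model.replace("/", "_"))):
        path = "models/{}".format(model.replace("/", "_"))
    elif os.path.isdir(directory):
        path = directory
    else:
        path = model  # last resort: let transformers resolve the name (may hit the Hub)
    return AutoConfig.from_pretrained(path, revision=revision, cache_dir="cache")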
utils.py | 2
@@ -167,7 +167,7 @@ def decodenewlines(txt):
 # Returns number of layers given an HF model config
 #==================================================================#
 def num_layers(config):
-    return config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else None
+    return config["n_layer"] if isinstance(config, dict) else config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else None
 
 #==================================================================#
 # Downloads huggingface checkpoints using aria2c if possible
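The one-line change in num_layers exists because the GPT2Custom branch above now hands it a raw dict from json.load, while every other path still hands it a Hugging Face config object. A small self-contained sketch of both call shapes; the sample layer counts and the FakeConfig stand-in are illustrative only.

def num_layers(config):
    # Same logic as the patched utils.num_layers: accept a plain dict from a
    # local config.json, otherwise fall back to the usual HF config attributes.
    return config["n_layer"] if isinstance(config, dict) else \
        config.num_layers if hasattr(config, "num_layers") else \
        config.n_layer if hasattr(config, "n_layer") else \
        config.num_hidden_layers if hasattr(config, "num_hidden_layers") else None

print(num_layers({"n_layer": 12}))   # dict loaded from a GPT2Custom config.json -> 12

class FakeConfig:                    # stand-in for an AutoConfig-style object
    num_hidden_layers = 24

print(num_layers(FakeConfig()))      # attribute lookup path -> 24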