mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-09 00:08:53 +01:00
Fix the model selection GUI when there is no internet connection
This commit is contained in:
parent
0d4bffe8f8
commit
55f45c4912
20
aiserver.py
20
aiserver.py
@ -1474,22 +1474,22 @@ def get_model_info(model, directory=""):
|
|||||||
|
|
||||||
def get_layer_count(model, directory=""):
|
def get_layer_count(model, directory=""):
|
||||||
if(model not in ["InferKit", "Colab", "API", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
if(model not in ["InferKit", "Colab", "API", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||||
if(vars.model == "GPT2Custom"):
|
if(model == "GPT2Custom"):
|
||||||
model_config = open(vars.custmodpth + "/config.json", "r")
|
with open(os.path.join(directory, "config.json"), "r") as f:
|
||||||
|
model_config = json.load(f)
|
||||||
# Get the model_type from the config or assume a model type if it isn't present
|
# Get the model_type from the config or assume a model type if it isn't present
|
||||||
else:
|
else:
|
||||||
|
if(directory):
|
||||||
|
model = directory
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
if directory == "":
|
if(os.path.isdir(model.replace('/', '_'))):
|
||||||
model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache")
|
model_config = AutoConfig.from_pretrained(model.replace('/', '_'), revision=vars.revision, cache_dir="cache")
|
||||||
|
elif(os.path.isdir("models/{}".format(model.replace('/', '_')))):
|
||||||
|
model_config = AutoConfig.from_pretrained("models/{}".format(model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
|
||||||
elif(os.path.isdir(directory)):
|
elif(os.path.isdir(directory)):
|
||||||
model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache")
|
model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache")
|
||||||
elif(os.path.isdir(vars.custmodpth.replace('/', '_'))):
|
|
||||||
model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache")
|
|
||||||
else:
|
else:
|
||||||
model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
|
model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return utils.num_layers(model_config)
|
return utils.num_layers(model_config)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
2
utils.py
2
utils.py
@ -167,7 +167,7 @@ def decodenewlines(txt):
|
|||||||
# Returns number of layers given an HF model config
|
# Returns number of layers given an HF model config
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
def num_layers(config):
|
def num_layers(config):
|
||||||
return config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else None
|
return config["n_layer"] if isinstance(config, dict) else config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else None
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Downloads huggingface checkpoints using aria2c if possible
|
# Downloads huggingface checkpoints using aria2c if possible
|
||||||
|
Loading…
x
Reference in New Issue
Block a user