From 55f45c49127d43a3cfffb13010f969cddbc0e664 Mon Sep 17 00:00:00 2001 From: vfbd Date: Mon, 22 Aug 2022 14:45:02 -0400 Subject: [PATCH] Fix the model selection GUI when there is no internet connection --- aiserver.py | 20 ++++++++++---------- utils.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/aiserver.py b/aiserver.py index ef785313..642ced7d 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1474,22 +1474,22 @@ def get_model_info(model, directory=""): def get_layer_count(model, directory=""): if(model not in ["InferKit", "Colab", "API", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]): - if(vars.model == "GPT2Custom"): - model_config = open(vars.custmodpth + "/config.json", "r") + if(model == "GPT2Custom"): + with open(os.path.join(directory, "config.json"), "r") as f: + model_config = json.load(f) # Get the model_type from the config or assume a model type if it isn't present else: + if(directory): + model = directory from transformers import AutoConfig - if directory == "": - model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache") + if(os.path.isdir(model.replace('/', '_'))): + model_config = AutoConfig.from_pretrained(model.replace('/', '_'), revision=vars.revision, cache_dir="cache") + elif(os.path.isdir("models/{}".format(model.replace('/', '_')))): + model_config = AutoConfig.from_pretrained("models/{}".format(model.replace('/', '_')), revision=vars.revision, cache_dir="cache") elif(os.path.isdir(directory)): model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache") - elif(os.path.isdir(vars.custmodpth.replace('/', '_'))): - model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache") else: - model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache") - - - + model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache") return utils.num_layers(model_config) else: return None diff --git a/utils.py b/utils.py index 7fd82072..44c1129a 100644 --- a/utils.py +++ b/utils.py @@ -167,7 +167,7 @@ def decodenewlines(txt): # Returns number of layers given an HF model config #==================================================================# def num_layers(config): - return config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else None + return config["n_layer"] if isinstance(config, dict) else config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else None #==================================================================# # Downloads huggingface checkpoints using aria2c if possible