From 4ff1a6e9405f9f74d1cd2058acf7a3096ca6b46c Mon Sep 17 00:00:00 2001 From: henk717 Date: Thu, 23 Dec 2021 02:50:06 +0100 Subject: [PATCH] Model Type support Automatically detect or assume the model type so we do not have to hardcode all the different models people might use. This almost makes the behavior of --model identical to the NeoCustom behavior as far as the CLI is concerned. But only if the model_type is defined in the models config file. --- aiserver.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/aiserver.py b/aiserver.py index d046a437..2c554174 100644 --- a/aiserver.py +++ b/aiserver.py @@ -82,6 +82,7 @@ class vars: submission = "" # Same as above, but after applying input formatting lastctx = "" # The last context submitted to the generator model = "" # Model ID string chosen at startup + model_type = "" # Model Type (Automatically taken from the model config) noai = False # Runs the script without starting up the transformers pipeline aibusy = False # Stops submissions while the AI is working max_length = 1024 # Maximum number of tokens to submit per action @@ -388,9 +389,29 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme vars.allowsp = True # Test for GPU support import torch + + # Make model path the same as the model name to make this consistent with the other loading method if it isn't a known model type + # This code is not just a workaround for below, it is also used to make the behavior consistent with other loading methods - Henk717 + if(not vars.model in ["NeoCustom", "GPT2Custom"]): + vars.custmodpth = vars.model.replace('/', '_') + # Get the model_type from the config or assume a model type if it isn't present + from transformers import AutoConfig + try: + model_config = AutoConfig.from_pretrained(vars.custmodpth) + except ValueError as e: + vars.model_type = "not_found" + if(not vars.model_type == "not_found"): + vars.model_type = model_config.model_type + elif(vars.model == "NeoCustom"): + vars.model_type = "gpt_neo" + elif(vars.model == "GPT2Custom"): + vars.model_type = "gpt2" + else: + print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)") + vars.model_type = "gpt_neo" print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="") vars.hascuda = torch.cuda.is_available() - vars.bmsupported = vars.model in ("EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", "EleutherAI/gpt-j-6B", "NeoCustom") + vars.bmsupported = vars.model_type in ("gpt_neo", "gptj") if(args.breakmodel is not None and args.breakmodel): print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --layers is used (see --help for details).", file=sys.stderr) if(args.breakmodel_layers is not None): @@ -864,9 +885,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # feature yet if("/" not in vars.model and vars.model.lower().startswith("gpt2")): lowmem = {} - - # Make model path the same as the model name to make this consistent with the other loading method - vars.custmodpth = vars.model.replace('/', '_') # Download model from Huggingface if it does not exist, otherwise load locally @@ -874,6 +892,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme with(maybe_use_float16()): tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/") model = AutoModelForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem) + else: print("Model does not exist locally, attempting to download from Huggingface...") tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/") @@ -884,7 +903,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme shutil.rmtree("cache/") model.save_pretrained(vars.model.replace('/', '_')) tokenizer.save_pretrained(vars.model.replace('/', '_')) - + if(vars.hascuda): if(vars.usegpu): vars.modeldim = get_hidden_size_from_model(model)