From be719a7e5e90b231f5ce6078ab3ad9c1c992d4d2 Mon Sep 17 00:00:00 2001 From: ebolam Date: Tue, 27 Sep 2022 19:02:37 -0400 Subject: [PATCH] Fix for loading models that don't support breakmodel (GPU/CPU support in UI) --- aiserver.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/aiserver.py b/aiserver.py index a80451f7..8c74e631 100644 --- a/aiserver.py +++ b/aiserver.py @@ -2230,7 +2230,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal loadsettings() logger.init("GPU support", status="Searching") vars.hascuda = torch.cuda.is_available() - vars.bmsupported = (utils.HAS_ACCELERATE or vars.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not vars.nobreakmodel + vars.bmsupported = ((utils.HAS_ACCELERATE and vars.model_type != 'gpt2') or vars.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not vars.nobreakmodel if(args.breakmodel is not None and args.breakmodel): logger.warning("--breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).") if(args.breakmodel_layers is not None): @@ -2256,7 +2256,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal vars.breakmodel = True else: vars.breakmodel = False - vars.usegpu = True + vars.usegpu = use_gpu # Ask for API key if InferKit was selected @@ -2430,9 +2430,14 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal yield False # If custom GPT2 model was chosen - if(vars.model == "GPT2Custom"): + if(vars.model_type == "gpt2"): vars.lazy_load = False - model_config = open(vars.custmodpth + "/config.json", "r") + if os.path.exists(vars.custmodpth): + model_config = open(vars.custmodpth + "/config.json", "r") + elif os.path.exists(os.path.join("models/", vars.custmodpth)): + config_path = os.path.join("models/", vars.custmodpth) + config_path = os.path.join(config_path, "config.json").replace("\\", "//") + model_config = open(config_path, "r") js = json.load(model_config) with(maybe_use_float16()): try: