Merge pull request #221 from ebolam/united

Fix for loading models that don't support breakmodel (GPU/CPU support)
This commit is contained in:
henk717 2022-09-28 01:32:08 +02:00 committed by GitHub
commit c935d8646a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 9 additions and 4 deletions

View File

@ -2232,7 +2232,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
loadsettings()
logger.init("GPU support", status="Searching")
vars.hascuda = torch.cuda.is_available() and not args.cpu
vars.bmsupported = (utils.HAS_ACCELERATE or vars.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not vars.nobreakmodel and not args.cpu
vars.bmsupported = ((utils.HAS_ACCELERATE and vars.model_type != 'gpt2') or vars.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not vars.nobreakmodel
if(args.breakmodel is not None and args.breakmodel):
logger.warning("--breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).")
if(args.breakmodel_layers is not None):
@ -2258,7 +2258,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
vars.breakmodel = True
else:
vars.breakmodel = False
vars.usegpu = True
vars.usegpu = use_gpu
# Ask for API key if InferKit was selected
@ -2432,9 +2432,14 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
yield False
# If custom GPT2 model was chosen
if(vars.model == "GPT2Custom"):
if(vars.model_type == "gpt2"):
vars.lazy_load = False
model_config = open(vars.custmodpth + "/config.json", "r")
if os.path.exists(vars.custmodpth):
model_config = open(vars.custmodpth + "/config.json", "r")
elif os.path.exists(os.path.join("models/", vars.custmodpth)):
config_path = os.path.join("models/", vars.custmodpth)
config_path = os.path.join(config_path, "config.json").replace("\\", "//")
model_config = open(config_path, "r")
js = json.load(model_config)
with(maybe_use_float16()):
try: