NoBreakmodel variable

Adds a Nobreakmodel var that allows Breakmodel to be turned off. This can be done trough commandline or a model config (In case Neo is used by the models config without it being a true Neo model that is compatible with breakmodel).

In addition I removed the args.colab check for breakmodel support and instead make args.colab activate nobreakmodel. And I have added a new check so that breakmodel is not even attempted if you do not specify the layers but do launch a model from the command line.
This commit is contained in:
henk717 2022-01-30 17:06:15 +01:00
parent 5b5a479f29
commit f0c0a990ea
1 changed files with 55 additions and 43 deletions

View File

@ -178,6 +178,7 @@ class vars:
useprompt = False # Whether to send the full prompt with every submit action
breakmodel = False # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only
bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J only, currently)
nobreakmodel = False # Something specifically requested Breakmodel to be disabled (For example a models config)
smandelete = False # Whether stories can be deleted from inside the browser
smanrename = False # Whether stories can be renamed from inside the browser
allowsp = False # Whether we are allowed to use soft prompts (by default enabled if we're using GPT-2, GPT-Neo or GPT-J)
@ -382,6 +383,50 @@ def device_config(model):
generator = model.generate
breakmodel.move_hidden_layers(model.transformer)
#==================================================================#
# Allow the models to override some settings
#==================================================================#
def loadmodelsettings():
try:
model_js_config = str(model_config).partition(' ')[2]
js = json.loads(model_js_config)
except Exception as e:
try:
model_js_config = open(vars.custmodpth + "/config.json", "r")
except Exception as e:
model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
js = json.load(model_js_config)
if("badwordsids" in js):
vars.badwordsids = js["badwordsids"]
if("nobreakmodel" in js):
vars.nobreakmodel = js["nobreakmodel"]
if("temp" in js):
vars.temp = js["temp"]
if("top_p" in js):
vars.top_p = js["top_p"]
if("top_k" in js):
vars.top_k = js["top_k"]
if("tfs" in js):
vars.tfs = js["tfs"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
vars.rep_pen_slope = js["rep_pen_slope"]
if("rep_pen_range" in js):
vars.rep_pen_range = js["rep_pen_range"]
if("adventure" in js):
vars.adventure = js["adventure"]
if("chatmode" in js):
vars.chatmode = js["chatmode"]
if("dynamicscan" in js):
vars.dynamicscan = js["dynamicscan"]
if("formatoptns" in js):
vars.formatoptns = js["formatoptns"]
if("antemplate" in js):
vars.setauthornotetemplate = js["antemplate"]
if(not vars.gamestarted):
vars.authornotetemplate = vars.setauthornotetemplate
#==================================================================#
# Startup
#==================================================================#
@ -400,6 +445,7 @@ parser.add_argument("--override_delete", action='store_true', help="Deleting sto
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")
parser.add_argument("--nobreakmodel", action='store_true', help="Disables Breakmodel support completely.")
args: argparse.Namespace = None
if(os.environ.get("KOBOLDAI_ARGS") is not None):
@ -414,7 +460,11 @@ if args.colab:
args.remote = True;
args.override_rename = True;
args.override_delete = True;
args.nobreakmodel = True;
if args.nobreakmodel:
vars.nobreakmodel = True;
if args.remote:
vars.remote = True;
@ -470,13 +520,16 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
elif(vars.model_type == "not_found"):
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
vars.model_type = "gpt_neo"
loadmodelsettings()
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
vars.hascuda = torch.cuda.is_available()
vars.bmsupported = vars.model_type in ("gpt_neo", "gptj") and not args.colab
vars.bmsupported = vars.model_type in ("gpt_neo", "gptj") and not vars.nobreakmodel
if(args.breakmodel is not None and args.breakmodel):
print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --layers is used (see --help for details).", file=sys.stderr)
if(args.breakmodel_layers is not None):
print("WARNING: --breakmodel_layers is deprecated. Use --layers instead (see --help for details).", file=sys.stderr)
if(args.model and vars.bmsupported and (args.breakmodel_gpulayers is None or args.breakmodel_layers is None)):
print("WARNING: Model launched without the --breakmodel_gpulayers argument, defaulting to GPU only mode.", file=sys.stderr)
if(not vars.bmsupported and (args.breakmodel_gpulayers is not None or args.breakmodel_layers is not None)):
print("WARNING: This model does not support hybrid generation. --layers will be ignored.", file=sys.stderr)
if(vars.hascuda):
@ -2261,47 +2314,6 @@ def loadsettings():
file.close()
#==================================================================#
# Allow the models to override some settings
#==================================================================#
def loadmodelsettings():
try:
model_js_config = str(model_config).partition(' ')[2]
js = json.loads(model_js_config)
except Exception as e:
try:
model_js_config = open(vars.custmodpth + "/config.json", "r")
except Exception as e:
model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
js = json.load(model_js_config)
if("badwordsids" in js):
vars.badwordsids = js["badwordsids"]
if("temp" in js):
vars.temp = js["temp"]
if("top_p" in js):
vars.top_p = js["top_p"]
if("top_k" in js):
vars.top_k = js["top_k"]
if("tfs" in js):
vars.tfs = js["tfs"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
vars.rep_pen_slope = js["rep_pen_slope"]
if("rep_pen_range" in js):
vars.rep_pen_range = js["rep_pen_range"]
if("adventure" in js):
vars.adventure = js["adventure"]
if("chatmode" in js):
vars.chatmode = js["chatmode"]
if("dynamicscan" in js):
vars.dynamicscan = js["dynamicscan"]
if("formatoptns" in js):
vars.formatoptns = js["formatoptns"]
if("antemplate" in js):
vars.setauthornotetemplate = js["antemplate"]
if(not vars.gamestarted):
vars.authornotetemplate = vars.setauthornotetemplate
#==================================================================#
# Don't save settings unless 2 seconds have passed without modification