From c45ba497c9521bc97332bbf23760a6b06739b30d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Wed, 23 Feb 2022 20:39:11 -0500 Subject: [PATCH] Load settings earlier to avoid TPU badwords issues --- aiserver.py | 188 ++++++++++++++++++++++++++-------------------------- 1 file changed, 95 insertions(+), 93 deletions(-) diff --git a/aiserver.py b/aiserver.py index ab26b5cd..48a362b7 100644 --- a/aiserver.py +++ b/aiserver.py @@ -461,6 +461,95 @@ def loadmodelsettings(): if(not vars.gamestarted): vars.authornotetemplate = vars.setauthornotetemplate +#==================================================================# +# Read settings from client file JSON and send to vars +#==================================================================# +def loadsettings(): + if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")): + # Read file contents into JSON object + file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r") + js = json.load(file) + + # Copy file contents to vars + if("apikey" in js): + vars.apikey = js["apikey"] + if("andepth" in js): + vars.andepth = js["andepth"] + if("temp" in js): + vars.temp = js["temp"] + if("top_p" in js): + vars.top_p = js["top_p"] + if("top_k" in js): + vars.top_k = js["top_k"] + if("tfs" in js): + vars.tfs = js["tfs"] + if("rep_pen" in js): + vars.rep_pen = js["rep_pen"] + if("rep_pen_slope" in js): + vars.rep_pen_slope = js["rep_pen_slope"] + if("rep_pen_range" in js): + vars.rep_pen_range = js["rep_pen_range"] + if("genamt" in js): + vars.genamt = js["genamt"] + if("max_length" in js): + vars.max_length = js["max_length"] + if("ikgen" in js): + vars.ikgen = js["ikgen"] + if("formatoptns" in js): + vars.formatoptns = js["formatoptns"] + if("numseqs" in js): + vars.numseqs = js["numseqs"] + if("widepth" in js): + vars.widepth = js["widepth"] + if("useprompt" in js): + vars.useprompt = js["useprompt"] + if("adventure" in js): + vars.adventure = js["adventure"] + if("chatmode" in js): + vars.chatmode = js["chatmode"] + if("chatname" in js): + vars.chatname = js["chatname"] + if("dynamicscan" in js): + vars.dynamicscan = js["dynamicscan"] + if("nopromptgen" in js): + vars.nopromptgen = js["nopromptgen"] + if("rngpersist" in js): + vars.rngpersist = js["rngpersist"] + if("nogenmod" in js): + vars.nogenmod = js["nogenmod"] + if("autosave" in js): + vars.autosave = js["autosave"] + if("newlinemode" in js): + vars.newlinemode = js["newlinemode"] + if("welcome" in js): + vars.welcome = js["welcome"] + + if("antemplate" in js): + vars.setauthornotetemplate = js["antemplate"] + if(not vars.gamestarted): + vars.authornotetemplate = vars.setauthornotetemplate + + if("userscripts" in js): + vars.userscripts = [] + for userscript in js["userscripts"]: + if type(userscript) is not str: + continue + userscript = userscript.strip() + if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)): + vars.userscripts.append(userscript) + + if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))): + vars.corescript = js["corescript"] + else: + vars.corescript = "default.lua" + + if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))): + spRequest(js["softprompt"]) + else: + vars.spfilename = "" + + file.close() + #==================================================================# # Startup #==================================================================# @@ -573,6 +662,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)") vars.model_type = "gpt_neo" loadmodelsettings() + loadsettings() print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="") vars.hascuda = torch.cuda.is_available() vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel @@ -1191,9 +1281,11 @@ else: if(vars.model == "Colab"): from transformers import GPT2TokenizerFast tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", cache_dir="cache/") + loadsettings() elif(vars.model == "OAI"): from transformers import GPT2TokenizerFast tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/") + loadsettings() # Load the TPU backend if requested elif(vars.model == "TPUMeshTransformerGPTJ"): print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) @@ -1207,10 +1299,13 @@ else: tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback loadmodelsettings() + loadsettings() tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig) vars.allowsp = True vars.modeldim = int(tpu_mtj_backend.params["d_model"]) tokenizer = tpu_mtj_backend.tokenizer + else: + loadsettings() # Set up Flask routes @app.route('/') @@ -2350,96 +2445,6 @@ def savesettings(): finally: file.close() -#==================================================================# -# Read settings from client file JSON and send to vars -#==================================================================# -def loadsettings(): - if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")): - # Read file contents into JSON object - file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r") - js = json.load(file) - - # Copy file contents to vars - if("apikey" in js): - vars.apikey = js["apikey"] - if("andepth" in js): - vars.andepth = js["andepth"] - if("temp" in js): - vars.temp = js["temp"] - if("top_p" in js): - vars.top_p = js["top_p"] - if("top_k" in js): - vars.top_k = js["top_k"] - if("tfs" in js): - vars.tfs = js["tfs"] - if("rep_pen" in js): - vars.rep_pen = js["rep_pen"] - if("rep_pen_slope" in js): - vars.rep_pen_slope = js["rep_pen_slope"] - if("rep_pen_range" in js): - vars.rep_pen_range = js["rep_pen_range"] - if("genamt" in js): - vars.genamt = js["genamt"] - if("max_length" in js): - vars.max_length = js["max_length"] - if("ikgen" in js): - vars.ikgen = js["ikgen"] - if("formatoptns" in js): - vars.formatoptns = js["formatoptns"] - if("numseqs" in js): - vars.numseqs = js["numseqs"] - if("widepth" in js): - vars.widepth = js["widepth"] - if("useprompt" in js): - vars.useprompt = js["useprompt"] - if("adventure" in js): - vars.adventure = js["adventure"] - if("chatmode" in js): - vars.chatmode = js["chatmode"] - if("chatname" in js): - vars.chatname = js["chatname"] - if("dynamicscan" in js): - vars.dynamicscan = js["dynamicscan"] - if("nopromptgen" in js): - vars.nopromptgen = js["nopromptgen"] - if("rngpersist" in js): - vars.rngpersist = js["rngpersist"] - if("nogenmod" in js): - vars.nogenmod = js["nogenmod"] - if("autosave" in js): - vars.autosave = js["autosave"] - if("newlinemode" in js): - vars.newlinemode = js["newlinemode"] - if("welcome" in js): - vars.welcome = js["welcome"] - - if("antemplate" in js): - vars.setauthornotetemplate = js["antemplate"] - if(not vars.gamestarted): - vars.authornotetemplate = vars.setauthornotetemplate - - if("userscripts" in js): - vars.userscripts = [] - for userscript in js["userscripts"]: - if type(userscript) is not str: - continue - userscript = userscript.strip() - if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)): - vars.userscripts.append(userscript) - - if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))): - vars.corescript = js["corescript"] - else: - vars.corescript = "default.lua" - - if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))): - spRequest(js["softprompt"]) - else: - vars.spfilename = "" - - file.close() - - #==================================================================# # Don't save settings unless 2 seconds have passed without modification #==================================================================# @@ -4886,9 +4891,6 @@ def randomGameRequest(topic, memory=""): vars.memory = memory emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True) -# Load desired settings from both the model and the users config file -loadsettings() - # Prevent tokenizer from taking extra time the first time it's used def __preempt_tokenizer(): if("tokenizer" not in globals()):