Load settings earlier to avoid TPU badwords issues
This commit is contained in:
parent
ac59e55d62
commit
c45ba497c9
188
aiserver.py
188
aiserver.py
|
@ -461,6 +461,95 @@ def loadmodelsettings():
|
|||
if(not vars.gamestarted):
|
||||
vars.authornotetemplate = vars.setauthornotetemplate
|
||||
|
||||
#==================================================================#
|
||||
# Read settings from client file JSON and send to vars
|
||||
#==================================================================#
|
||||
def loadsettings():
|
||||
if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
|
||||
# Read file contents into JSON object
|
||||
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
|
||||
js = json.load(file)
|
||||
|
||||
# Copy file contents to vars
|
||||
if("apikey" in js):
|
||||
vars.apikey = js["apikey"]
|
||||
if("andepth" in js):
|
||||
vars.andepth = js["andepth"]
|
||||
if("temp" in js):
|
||||
vars.temp = js["temp"]
|
||||
if("top_p" in js):
|
||||
vars.top_p = js["top_p"]
|
||||
if("top_k" in js):
|
||||
vars.top_k = js["top_k"]
|
||||
if("tfs" in js):
|
||||
vars.tfs = js["tfs"]
|
||||
if("rep_pen" in js):
|
||||
vars.rep_pen = js["rep_pen"]
|
||||
if("rep_pen_slope" in js):
|
||||
vars.rep_pen_slope = js["rep_pen_slope"]
|
||||
if("rep_pen_range" in js):
|
||||
vars.rep_pen_range = js["rep_pen_range"]
|
||||
if("genamt" in js):
|
||||
vars.genamt = js["genamt"]
|
||||
if("max_length" in js):
|
||||
vars.max_length = js["max_length"]
|
||||
if("ikgen" in js):
|
||||
vars.ikgen = js["ikgen"]
|
||||
if("formatoptns" in js):
|
||||
vars.formatoptns = js["formatoptns"]
|
||||
if("numseqs" in js):
|
||||
vars.numseqs = js["numseqs"]
|
||||
if("widepth" in js):
|
||||
vars.widepth = js["widepth"]
|
||||
if("useprompt" in js):
|
||||
vars.useprompt = js["useprompt"]
|
||||
if("adventure" in js):
|
||||
vars.adventure = js["adventure"]
|
||||
if("chatmode" in js):
|
||||
vars.chatmode = js["chatmode"]
|
||||
if("chatname" in js):
|
||||
vars.chatname = js["chatname"]
|
||||
if("dynamicscan" in js):
|
||||
vars.dynamicscan = js["dynamicscan"]
|
||||
if("nopromptgen" in js):
|
||||
vars.nopromptgen = js["nopromptgen"]
|
||||
if("rngpersist" in js):
|
||||
vars.rngpersist = js["rngpersist"]
|
||||
if("nogenmod" in js):
|
||||
vars.nogenmod = js["nogenmod"]
|
||||
if("autosave" in js):
|
||||
vars.autosave = js["autosave"]
|
||||
if("newlinemode" in js):
|
||||
vars.newlinemode = js["newlinemode"]
|
||||
if("welcome" in js):
|
||||
vars.welcome = js["welcome"]
|
||||
|
||||
if("antemplate" in js):
|
||||
vars.setauthornotetemplate = js["antemplate"]
|
||||
if(not vars.gamestarted):
|
||||
vars.authornotetemplate = vars.setauthornotetemplate
|
||||
|
||||
if("userscripts" in js):
|
||||
vars.userscripts = []
|
||||
for userscript in js["userscripts"]:
|
||||
if type(userscript) is not str:
|
||||
continue
|
||||
userscript = userscript.strip()
|
||||
if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
|
||||
vars.userscripts.append(userscript)
|
||||
|
||||
if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
|
||||
vars.corescript = js["corescript"]
|
||||
else:
|
||||
vars.corescript = "default.lua"
|
||||
|
||||
if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
|
||||
spRequest(js["softprompt"])
|
||||
else:
|
||||
vars.spfilename = ""
|
||||
|
||||
file.close()
|
||||
|
||||
#==================================================================#
|
||||
# Startup
|
||||
#==================================================================#
|
||||
|
@ -573,6 +662,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
||||
vars.model_type = "gpt_neo"
|
||||
loadmodelsettings()
|
||||
loadsettings()
|
||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||
vars.hascuda = torch.cuda.is_available()
|
||||
vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel
|
||||
|
@ -1191,9 +1281,11 @@ else:
|
|||
if(vars.model == "Colab"):
|
||||
from transformers import GPT2TokenizerFast
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", cache_dir="cache/")
|
||||
loadsettings()
|
||||
elif(vars.model == "OAI"):
|
||||
from transformers import GPT2TokenizerFast
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
||||
loadsettings()
|
||||
# Load the TPU backend if requested
|
||||
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
||||
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||
|
@ -1207,10 +1299,13 @@ else:
|
|||
tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
|
||||
tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
|
||||
loadmodelsettings()
|
||||
loadsettings()
|
||||
tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
|
||||
vars.allowsp = True
|
||||
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||
tokenizer = tpu_mtj_backend.tokenizer
|
||||
else:
|
||||
loadsettings()
|
||||
|
||||
# Set up Flask routes
|
||||
@app.route('/')
|
||||
|
@ -2350,96 +2445,6 @@ def savesettings():
|
|||
finally:
|
||||
file.close()
|
||||
|
||||
#==================================================================#
|
||||
# Read settings from client file JSON and send to vars
|
||||
#==================================================================#
|
||||
def loadsettings():
|
||||
if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
|
||||
# Read file contents into JSON object
|
||||
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
|
||||
js = json.load(file)
|
||||
|
||||
# Copy file contents to vars
|
||||
if("apikey" in js):
|
||||
vars.apikey = js["apikey"]
|
||||
if("andepth" in js):
|
||||
vars.andepth = js["andepth"]
|
||||
if("temp" in js):
|
||||
vars.temp = js["temp"]
|
||||
if("top_p" in js):
|
||||
vars.top_p = js["top_p"]
|
||||
if("top_k" in js):
|
||||
vars.top_k = js["top_k"]
|
||||
if("tfs" in js):
|
||||
vars.tfs = js["tfs"]
|
||||
if("rep_pen" in js):
|
||||
vars.rep_pen = js["rep_pen"]
|
||||
if("rep_pen_slope" in js):
|
||||
vars.rep_pen_slope = js["rep_pen_slope"]
|
||||
if("rep_pen_range" in js):
|
||||
vars.rep_pen_range = js["rep_pen_range"]
|
||||
if("genamt" in js):
|
||||
vars.genamt = js["genamt"]
|
||||
if("max_length" in js):
|
||||
vars.max_length = js["max_length"]
|
||||
if("ikgen" in js):
|
||||
vars.ikgen = js["ikgen"]
|
||||
if("formatoptns" in js):
|
||||
vars.formatoptns = js["formatoptns"]
|
||||
if("numseqs" in js):
|
||||
vars.numseqs = js["numseqs"]
|
||||
if("widepth" in js):
|
||||
vars.widepth = js["widepth"]
|
||||
if("useprompt" in js):
|
||||
vars.useprompt = js["useprompt"]
|
||||
if("adventure" in js):
|
||||
vars.adventure = js["adventure"]
|
||||
if("chatmode" in js):
|
||||
vars.chatmode = js["chatmode"]
|
||||
if("chatname" in js):
|
||||
vars.chatname = js["chatname"]
|
||||
if("dynamicscan" in js):
|
||||
vars.dynamicscan = js["dynamicscan"]
|
||||
if("nopromptgen" in js):
|
||||
vars.nopromptgen = js["nopromptgen"]
|
||||
if("rngpersist" in js):
|
||||
vars.rngpersist = js["rngpersist"]
|
||||
if("nogenmod" in js):
|
||||
vars.nogenmod = js["nogenmod"]
|
||||
if("autosave" in js):
|
||||
vars.autosave = js["autosave"]
|
||||
if("newlinemode" in js):
|
||||
vars.newlinemode = js["newlinemode"]
|
||||
if("welcome" in js):
|
||||
vars.welcome = js["welcome"]
|
||||
|
||||
if("antemplate" in js):
|
||||
vars.setauthornotetemplate = js["antemplate"]
|
||||
if(not vars.gamestarted):
|
||||
vars.authornotetemplate = vars.setauthornotetemplate
|
||||
|
||||
if("userscripts" in js):
|
||||
vars.userscripts = []
|
||||
for userscript in js["userscripts"]:
|
||||
if type(userscript) is not str:
|
||||
continue
|
||||
userscript = userscript.strip()
|
||||
if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
|
||||
vars.userscripts.append(userscript)
|
||||
|
||||
if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
|
||||
vars.corescript = js["corescript"]
|
||||
else:
|
||||
vars.corescript = "default.lua"
|
||||
|
||||
if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
|
||||
spRequest(js["softprompt"])
|
||||
else:
|
||||
vars.spfilename = ""
|
||||
|
||||
file.close()
|
||||
|
||||
|
||||
#==================================================================#
|
||||
# Don't save settings unless 2 seconds have passed without modification
|
||||
#==================================================================#
|
||||
|
@ -4886,9 +4891,6 @@ def randomGameRequest(topic, memory=""):
|
|||
vars.memory = memory
|
||||
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
|
||||
|
||||
# Load desired settings from both the model and the users config file
|
||||
loadsettings()
|
||||
|
||||
# Prevent tokenizer from taking extra time the first time it's used
|
||||
def __preempt_tokenizer():
|
||||
if("tokenizer" not in globals()):
|
||||
|
|
Loading…
Reference in New Issue