diff --git a/aiserver.py b/aiserver.py index 137d7819..3c3aa9a2 100644 --- a/aiserver.py +++ b/aiserver.py @@ -117,6 +117,7 @@ class vars: actionmode = 1 adventure = False remote = False + msoverride = True # Allow models to override settings #==================================================================# # Function to get model selection at startup @@ -478,13 +479,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): else: generator = pipeline('text-generation', model=vars.model) - # Suppress Author's Note by flagging square brackets + # Suppress Author's Note by flagging square brackets (Old implementation) #vocab = tokenizer.get_vocab() #vocab_keys = vocab.keys() #vars.badwords = gettokenids("[") #for key in vars.badwords: # vars.badwordsids.append([vocab[key]]) - + print("{0}OK! {1} pipeline created!{2}".format(colors.GREEN, vars.model, colors.END)) else: # If we're running Colab or OAI, we still need a tokenizer. @@ -817,6 +818,7 @@ def savesettings(): js["widepth"] = vars.widepth js["useprompt"] = vars.useprompt js["adventure"] = vars.adventure + js["msoverride"] = vars.msoverride # Write it file = open("client.settings", "w") @@ -865,9 +867,41 @@ def loadsettings(): vars.useprompt = js["useprompt"] if("adventure" in js): vars.adventure = js["adventure"] + if("msoverride" in js): + vars.msoverride = js["msoverride"] file.close() +#==================================================================# +# Allow the models to override some settings +#==================================================================# +def loadmodelsettings(): + if(path.exists(vars.custmodpth + "/config.json")): + model_config = open(vars.custmodpth + "/config.json", "r") + js = json.load(model_config) + if("badwordsids" in js): + vars.badwordsids = js["badwordsids"] + if vars.msoverride: + if("temp" in js): + print("temp forced by model") + vars.temp = js["temp"] + if("top_p" in js): + print("top_p forced by model") + vars.top_p = js["top_p"] + if("top_k" in js): + print("top_k forced by model") + vars.top_k = js["top_k"] + if("tfs" in js): + print("tfs forced by model") + vars.tfs = js["tfs"] + if("rep_pen" in js): + print("Repetition Penalty forced by model") + vars.rep_pen = js["rep_pen"] + if("adventure" in js): + print("Adventure mode enabled/disabled by model") + vars.adventure = js["adventure"] + model_config.close() + #==================================================================# # Don't save settings unless 2 seconds have passed without modification #==================================================================# @@ -2304,7 +2338,7 @@ if __name__ == "__main__": # Load settings from client.settings loadsettings() - + loadmodelsettings() # Start Flask/SocketIO (Blocking, so this must be last method!) #socketio.run(app, host='0.0.0.0', port=5000)