diff --git a/.gitignore b/.gitignore index 15343072..dd922f16 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ client.settings # Ignore stories file except for test_story stories/* +settings/* !stories/sample_story.json /.project *.bak diff --git a/aiserver.py b/aiserver.py index 3c3aa9a2..ac759d0f 100644 --- a/aiserver.py +++ b/aiserver.py @@ -117,7 +117,6 @@ class vars: actionmode = 1 adventure = False remote = False - msoverride = True # Allow models to override settings #==================================================================# # Function to get model selection at startup @@ -163,6 +162,20 @@ def gettokenids(char): keys.append(key) return keys +#==================================================================# +# Return Model Name +#==================================================================# +def getmodelname(): + if(args.configname): + modelname = args.configname + return modelname + if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"): + modelname = os.path.basename(os.path.normpath(vars.custmodpth)) + return modelname + else: + modelname = vars.model + return modelname + #==================================================================# # Startup #==================================================================# @@ -177,6 +190,8 @@ parser.add_argument("--breakmodel", action='store_true', help="For models that s parser.add_argument("--breakmodel_layers", type=int, help="Specify the number of layers to commit to system RAM if --breakmodel is used") parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.") parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.") +parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.") + args = parser.parse_args() vars.model = args.model; @@ -255,12 +270,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): # Ask for API key if InferKit was selected if(vars.model == "InferKit"): - if(not path.exists("client.settings")): + if(not path.exists("settings/" + getmodelname() + ".settings")): # If the client settings file doesn't exist, create it print("{0}Please enter your InferKit API key:{1}\n".format(colors.CYAN, colors.END)) vars.apikey = input("Key> ") # Write API key to file - file = open("client.settings", "w") + file = open("settings/" + getmodelname() + ".settings", "w") try: js = {"apikey": vars.apikey} file.write(json.dumps(js, indent=3)) @@ -268,7 +283,7 @@ if(vars.model == "InferKit"): file.close() else: # Otherwise open it up - file = open("client.settings", "r") + file = open("settings/" + getmodelname() + ".settings", "r") # Check if API key exists js = json.load(file) if("apikey" in js and js["apikey"] != ""): @@ -281,7 +296,7 @@ if(vars.model == "InferKit"): vars.apikey = input("Key> ") js["apikey"] = vars.apikey # Write API key to file - file = open("client.settings", "w") + file = open("settings/" + getmodelname() + ".settings", "w") try: file.write(json.dumps(js, indent=3)) finally: @@ -289,12 +304,12 @@ if(vars.model == "InferKit"): # Ask for API key if OpenAI was selected if(vars.model == "OAI"): - if(not path.exists("client.settings")): + if(not path.exists("settings/" + getmodelname() + ".settings")): # If the client settings file doesn't exist, create it print("{0}Please enter your OpenAI API key:{1}\n".format(colors.CYAN, colors.END)) vars.oaiapikey = input("Key> ") # Write API key to file - file = open("client.settings", "w") + file = open("settings/" + getmodelname() + ".settings", "w") try: js = {"oaiapikey": vars.oaiapikey} file.write(json.dumps(js, indent=3)) @@ -302,7 +317,7 @@ if(vars.model == "OAI"): file.close() else: # Otherwise open it up - file = open("client.settings", "r") + file = open("settings/" + getmodelname() + ".settings", "r") # Check if API key exists js = json.load(file) if("oaiapikey" in js and js["oaiapikey"] != ""): @@ -315,7 +330,7 @@ if(vars.model == "OAI"): vars.oaiapikey = input("Key> ") js["oaiapikey"] = vars.oaiapikey # Write API key to file - file = open("client.settings", "w") + file = open("settings/" + getmodelname() + ".settings", "w") try: file.write(json.dumps(js, indent=3)) finally: @@ -818,10 +833,11 @@ def savesettings(): js["widepth"] = vars.widepth js["useprompt"] = vars.useprompt js["adventure"] = vars.adventure - js["msoverride"] = vars.msoverride # Write it - file = open("client.settings", "w") + if not os.path.exists('settings'): + os.mkdir('settings') + file = open("settings/" + getmodelname() + ".settings", "w") try: file.write(json.dumps(js, indent=3)) finally: @@ -831,9 +847,9 @@ def savesettings(): # Read settings from client file JSON and send to vars #==================================================================# def loadsettings(): - if(path.exists("client.settings")): + if(path.exists("settings/" + getmodelname() + ".settings")): # Read file contents into JSON object - file = open("client.settings", "r") + file = open("settings/" + getmodelname() + ".settings", "r") js = json.load(file) # Copy file contents to vars @@ -867,8 +883,6 @@ def loadsettings(): vars.useprompt = js["useprompt"] if("adventure" in js): vars.adventure = js["adventure"] - if("msoverride" in js): - vars.msoverride = js["msoverride"] file.close() @@ -881,25 +895,18 @@ def loadmodelsettings(): js = json.load(model_config) if("badwordsids" in js): vars.badwordsids = js["badwordsids"] - if vars.msoverride: - if("temp" in js): - print("temp forced by model") - vars.temp = js["temp"] - if("top_p" in js): - print("top_p forced by model") - vars.top_p = js["top_p"] - if("top_k" in js): - print("top_k forced by model") - vars.top_k = js["top_k"] - if("tfs" in js): - print("tfs forced by model") - vars.tfs = js["tfs"] - if("rep_pen" in js): - print("Repetition Penalty forced by model") - vars.rep_pen = js["rep_pen"] - if("adventure" in js): - print("Adventure mode enabled/disabled by model") - vars.adventure = js["adventure"] + if("temp" in js): + vars.temp = js["temp"] + if("top_p" in js): + vars.top_p = js["top_p"] + if("top_k" in js): + vars.top_k = js["top_k"] + if("tfs" in js): + vars.tfs = js["tfs"] + if("rep_pen" in js): + vars.rep_pen = js["rep_pen"] + if("adventure" in js): + vars.adventure = js["adventure"] model_config.close() #==================================================================# @@ -2337,8 +2344,9 @@ def randomGameRequest(topic): if __name__ == "__main__": # Load settings from client.settings - loadsettings() loadmodelsettings() + loadsettings() + # Start Flask/SocketIO (Blocking, so this must be last method!) #socketio.run(app, host='0.0.0.0', port=5000)