Settings per Model
Models can no longer override client settings, instead settings are now saved on a model per model basis with the settings provided by the model being the default. Users can also specify the desired configuration name as a command line parameter to avoid conflicting file names (Such as all Colabs having Colab.settings by default).
This commit is contained in:
parent
fbd07d82d7
commit
1df051a420
|
@ -3,6 +3,7 @@ client.settings
|
|||
|
||||
# Ignore stories file except for test_story
|
||||
stories/*
|
||||
settings/*
|
||||
!stories/sample_story.json
|
||||
/.project
|
||||
*.bak
|
||||
|
|
78
aiserver.py
78
aiserver.py
|
@ -117,7 +117,6 @@ class vars:
|
|||
actionmode = 1
|
||||
adventure = False
|
||||
remote = False
|
||||
msoverride = True # Allow models to override settings
|
||||
|
||||
#==================================================================#
|
||||
# Function to get model selection at startup
|
||||
|
@ -163,6 +162,20 @@ def gettokenids(char):
|
|||
keys.append(key)
|
||||
return keys
|
||||
|
||||
#==================================================================#
|
||||
# Return Model Name
|
||||
#==================================================================#
|
||||
def getmodelname():
|
||||
if(args.configname):
|
||||
modelname = args.configname
|
||||
return modelname
|
||||
if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
|
||||
modelname = os.path.basename(os.path.normpath(vars.custmodpth))
|
||||
return modelname
|
||||
else:
|
||||
modelname = vars.model
|
||||
return modelname
|
||||
|
||||
#==================================================================#
|
||||
# Startup
|
||||
#==================================================================#
|
||||
|
@ -177,6 +190,8 @@ parser.add_argument("--breakmodel", action='store_true', help="For models that s
|
|||
parser.add_argument("--breakmodel_layers", type=int, help="Specify the number of layers to commit to system RAM if --breakmodel is used")
|
||||
parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.")
|
||||
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
|
||||
parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
|
||||
|
||||
args = parser.parse_args()
|
||||
vars.model = args.model;
|
||||
|
||||
|
@ -255,12 +270,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||
|
||||
# Ask for API key if InferKit was selected
|
||||
if(vars.model == "InferKit"):
|
||||
if(not path.exists("client.settings")):
|
||||
if(not path.exists("settings/" + getmodelname() + ".settings")):
|
||||
# If the client settings file doesn't exist, create it
|
||||
print("{0}Please enter your InferKit API key:{1}\n".format(colors.CYAN, colors.END))
|
||||
vars.apikey = input("Key> ")
|
||||
# Write API key to file
|
||||
file = open("client.settings", "w")
|
||||
file = open("settings/" + getmodelname() + ".settings", "w")
|
||||
try:
|
||||
js = {"apikey": vars.apikey}
|
||||
file.write(json.dumps(js, indent=3))
|
||||
|
@ -268,7 +283,7 @@ if(vars.model == "InferKit"):
|
|||
file.close()
|
||||
else:
|
||||
# Otherwise open it up
|
||||
file = open("client.settings", "r")
|
||||
file = open("settings/" + getmodelname() + ".settings", "r")
|
||||
# Check if API key exists
|
||||
js = json.load(file)
|
||||
if("apikey" in js and js["apikey"] != ""):
|
||||
|
@ -281,7 +296,7 @@ if(vars.model == "InferKit"):
|
|||
vars.apikey = input("Key> ")
|
||||
js["apikey"] = vars.apikey
|
||||
# Write API key to file
|
||||
file = open("client.settings", "w")
|
||||
file = open("settings/" + getmodelname() + ".settings", "w")
|
||||
try:
|
||||
file.write(json.dumps(js, indent=3))
|
||||
finally:
|
||||
|
@ -289,12 +304,12 @@ if(vars.model == "InferKit"):
|
|||
|
||||
# Ask for API key if OpenAI was selected
|
||||
if(vars.model == "OAI"):
|
||||
if(not path.exists("client.settings")):
|
||||
if(not path.exists("settings/" + getmodelname() + ".settings")):
|
||||
# If the client settings file doesn't exist, create it
|
||||
print("{0}Please enter your OpenAI API key:{1}\n".format(colors.CYAN, colors.END))
|
||||
vars.oaiapikey = input("Key> ")
|
||||
# Write API key to file
|
||||
file = open("client.settings", "w")
|
||||
file = open("settings/" + getmodelname() + ".settings", "w")
|
||||
try:
|
||||
js = {"oaiapikey": vars.oaiapikey}
|
||||
file.write(json.dumps(js, indent=3))
|
||||
|
@ -302,7 +317,7 @@ if(vars.model == "OAI"):
|
|||
file.close()
|
||||
else:
|
||||
# Otherwise open it up
|
||||
file = open("client.settings", "r")
|
||||
file = open("settings/" + getmodelname() + ".settings", "r")
|
||||
# Check if API key exists
|
||||
js = json.load(file)
|
||||
if("oaiapikey" in js and js["oaiapikey"] != ""):
|
||||
|
@ -315,7 +330,7 @@ if(vars.model == "OAI"):
|
|||
vars.oaiapikey = input("Key> ")
|
||||
js["oaiapikey"] = vars.oaiapikey
|
||||
# Write API key to file
|
||||
file = open("client.settings", "w")
|
||||
file = open("settings/" + getmodelname() + ".settings", "w")
|
||||
try:
|
||||
file.write(json.dumps(js, indent=3))
|
||||
finally:
|
||||
|
@ -818,10 +833,11 @@ def savesettings():
|
|||
js["widepth"] = vars.widepth
|
||||
js["useprompt"] = vars.useprompt
|
||||
js["adventure"] = vars.adventure
|
||||
js["msoverride"] = vars.msoverride
|
||||
|
||||
# Write it
|
||||
file = open("client.settings", "w")
|
||||
if not os.path.exists('settings'):
|
||||
os.mkdir('settings')
|
||||
file = open("settings/" + getmodelname() + ".settings", "w")
|
||||
try:
|
||||
file.write(json.dumps(js, indent=3))
|
||||
finally:
|
||||
|
@ -831,9 +847,9 @@ def savesettings():
|
|||
# Read settings from client file JSON and send to vars
|
||||
#==================================================================#
|
||||
def loadsettings():
|
||||
if(path.exists("client.settings")):
|
||||
if(path.exists("settings/" + getmodelname() + ".settings")):
|
||||
# Read file contents into JSON object
|
||||
file = open("client.settings", "r")
|
||||
file = open("settings/" + getmodelname() + ".settings", "r")
|
||||
js = json.load(file)
|
||||
|
||||
# Copy file contents to vars
|
||||
|
@ -867,8 +883,6 @@ def loadsettings():
|
|||
vars.useprompt = js["useprompt"]
|
||||
if("adventure" in js):
|
||||
vars.adventure = js["adventure"]
|
||||
if("msoverride" in js):
|
||||
vars.msoverride = js["msoverride"]
|
||||
|
||||
file.close()
|
||||
|
||||
|
@ -881,25 +895,18 @@ def loadmodelsettings():
|
|||
js = json.load(model_config)
|
||||
if("badwordsids" in js):
|
||||
vars.badwordsids = js["badwordsids"]
|
||||
if vars.msoverride:
|
||||
if("temp" in js):
|
||||
print("temp forced by model")
|
||||
vars.temp = js["temp"]
|
||||
if("top_p" in js):
|
||||
print("top_p forced by model")
|
||||
vars.top_p = js["top_p"]
|
||||
if("top_k" in js):
|
||||
print("top_k forced by model")
|
||||
vars.top_k = js["top_k"]
|
||||
if("tfs" in js):
|
||||
print("tfs forced by model")
|
||||
vars.tfs = js["tfs"]
|
||||
if("rep_pen" in js):
|
||||
print("Repetition Penalty forced by model")
|
||||
vars.rep_pen = js["rep_pen"]
|
||||
if("adventure" in js):
|
||||
print("Adventure mode enabled/disabled by model")
|
||||
vars.adventure = js["adventure"]
|
||||
if("temp" in js):
|
||||
vars.temp = js["temp"]
|
||||
if("top_p" in js):
|
||||
vars.top_p = js["top_p"]
|
||||
if("top_k" in js):
|
||||
vars.top_k = js["top_k"]
|
||||
if("tfs" in js):
|
||||
vars.tfs = js["tfs"]
|
||||
if("rep_pen" in js):
|
||||
vars.rep_pen = js["rep_pen"]
|
||||
if("adventure" in js):
|
||||
vars.adventure = js["adventure"]
|
||||
model_config.close()
|
||||
|
||||
#==================================================================#
|
||||
|
@ -2337,8 +2344,9 @@ def randomGameRequest(topic):
|
|||
if __name__ == "__main__":
|
||||
|
||||
# Load settings from client.settings
|
||||
loadsettings()
|
||||
loadmodelsettings()
|
||||
loadsettings()
|
||||
|
||||
# Start Flask/SocketIO (Blocking, so this must be last method!)
|
||||
|
||||
#socketio.run(app, host='0.0.0.0', port=5000)
|
||||
|
|
Loading…
Reference in New Issue