mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Settings per Model
Models can no longer override client settings, instead settings are now saved on a model per model basis with the settings provided by the model being the default. Users can also specify the desired configuration name as a command line parameter to avoid conflicting file names (Such as all Colabs having Colab.settings by default).
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,6 +3,7 @@ client.settings
|
|||||||
|
|
||||||
# Ignore stories file except for test_story
|
# Ignore stories file except for test_story
|
||||||
stories/*
|
stories/*
|
||||||
|
settings/*
|
||||||
!stories/sample_story.json
|
!stories/sample_story.json
|
||||||
/.project
|
/.project
|
||||||
*.bak
|
*.bak
|
||||||
|
54
aiserver.py
54
aiserver.py
@ -117,7 +117,6 @@ class vars:
|
|||||||
actionmode = 1
|
actionmode = 1
|
||||||
adventure = False
|
adventure = False
|
||||||
remote = False
|
remote = False
|
||||||
msoverride = True # Allow models to override settings
|
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Function to get model selection at startup
|
# Function to get model selection at startup
|
||||||
@ -163,6 +162,20 @@ def gettokenids(char):
|
|||||||
keys.append(key)
|
keys.append(key)
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Return Model Name
|
||||||
|
#==================================================================#
|
||||||
|
def getmodelname():
|
||||||
|
if(args.configname):
|
||||||
|
modelname = args.configname
|
||||||
|
return modelname
|
||||||
|
if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
|
||||||
|
modelname = os.path.basename(os.path.normpath(vars.custmodpth))
|
||||||
|
return modelname
|
||||||
|
else:
|
||||||
|
modelname = vars.model
|
||||||
|
return modelname
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Startup
|
# Startup
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@ -177,6 +190,8 @@ parser.add_argument("--breakmodel", action='store_true', help="For models that s
|
|||||||
parser.add_argument("--breakmodel_layers", type=int, help="Specify the number of layers to commit to system RAM if --breakmodel is used")
|
parser.add_argument("--breakmodel_layers", type=int, help="Specify the number of layers to commit to system RAM if --breakmodel is used")
|
||||||
parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.")
|
parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.")
|
||||||
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
|
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
|
||||||
|
parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
vars.model = args.model;
|
vars.model = args.model;
|
||||||
|
|
||||||
@ -255,12 +270,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||||||
|
|
||||||
# Ask for API key if InferKit was selected
|
# Ask for API key if InferKit was selected
|
||||||
if(vars.model == "InferKit"):
|
if(vars.model == "InferKit"):
|
||||||
if(not path.exists("client.settings")):
|
if(not path.exists("settings/" + getmodelname() + ".settings")):
|
||||||
# If the client settings file doesn't exist, create it
|
# If the client settings file doesn't exist, create it
|
||||||
print("{0}Please enter your InferKit API key:{1}\n".format(colors.CYAN, colors.END))
|
print("{0}Please enter your InferKit API key:{1}\n".format(colors.CYAN, colors.END))
|
||||||
vars.apikey = input("Key> ")
|
vars.apikey = input("Key> ")
|
||||||
# Write API key to file
|
# Write API key to file
|
||||||
file = open("client.settings", "w")
|
file = open("settings/" + getmodelname() + ".settings", "w")
|
||||||
try:
|
try:
|
||||||
js = {"apikey": vars.apikey}
|
js = {"apikey": vars.apikey}
|
||||||
file.write(json.dumps(js, indent=3))
|
file.write(json.dumps(js, indent=3))
|
||||||
@ -268,7 +283,7 @@ if(vars.model == "InferKit"):
|
|||||||
file.close()
|
file.close()
|
||||||
else:
|
else:
|
||||||
# Otherwise open it up
|
# Otherwise open it up
|
||||||
file = open("client.settings", "r")
|
file = open("settings/" + getmodelname() + ".settings", "r")
|
||||||
# Check if API key exists
|
# Check if API key exists
|
||||||
js = json.load(file)
|
js = json.load(file)
|
||||||
if("apikey" in js and js["apikey"] != ""):
|
if("apikey" in js and js["apikey"] != ""):
|
||||||
@ -281,7 +296,7 @@ if(vars.model == "InferKit"):
|
|||||||
vars.apikey = input("Key> ")
|
vars.apikey = input("Key> ")
|
||||||
js["apikey"] = vars.apikey
|
js["apikey"] = vars.apikey
|
||||||
# Write API key to file
|
# Write API key to file
|
||||||
file = open("client.settings", "w")
|
file = open("settings/" + getmodelname() + ".settings", "w")
|
||||||
try:
|
try:
|
||||||
file.write(json.dumps(js, indent=3))
|
file.write(json.dumps(js, indent=3))
|
||||||
finally:
|
finally:
|
||||||
@ -289,12 +304,12 @@ if(vars.model == "InferKit"):
|
|||||||
|
|
||||||
# Ask for API key if OpenAI was selected
|
# Ask for API key if OpenAI was selected
|
||||||
if(vars.model == "OAI"):
|
if(vars.model == "OAI"):
|
||||||
if(not path.exists("client.settings")):
|
if(not path.exists("settings/" + getmodelname() + ".settings")):
|
||||||
# If the client settings file doesn't exist, create it
|
# If the client settings file doesn't exist, create it
|
||||||
print("{0}Please enter your OpenAI API key:{1}\n".format(colors.CYAN, colors.END))
|
print("{0}Please enter your OpenAI API key:{1}\n".format(colors.CYAN, colors.END))
|
||||||
vars.oaiapikey = input("Key> ")
|
vars.oaiapikey = input("Key> ")
|
||||||
# Write API key to file
|
# Write API key to file
|
||||||
file = open("client.settings", "w")
|
file = open("settings/" + getmodelname() + ".settings", "w")
|
||||||
try:
|
try:
|
||||||
js = {"oaiapikey": vars.oaiapikey}
|
js = {"oaiapikey": vars.oaiapikey}
|
||||||
file.write(json.dumps(js, indent=3))
|
file.write(json.dumps(js, indent=3))
|
||||||
@ -302,7 +317,7 @@ if(vars.model == "OAI"):
|
|||||||
file.close()
|
file.close()
|
||||||
else:
|
else:
|
||||||
# Otherwise open it up
|
# Otherwise open it up
|
||||||
file = open("client.settings", "r")
|
file = open("settings/" + getmodelname() + ".settings", "r")
|
||||||
# Check if API key exists
|
# Check if API key exists
|
||||||
js = json.load(file)
|
js = json.load(file)
|
||||||
if("oaiapikey" in js and js["oaiapikey"] != ""):
|
if("oaiapikey" in js and js["oaiapikey"] != ""):
|
||||||
@ -315,7 +330,7 @@ if(vars.model == "OAI"):
|
|||||||
vars.oaiapikey = input("Key> ")
|
vars.oaiapikey = input("Key> ")
|
||||||
js["oaiapikey"] = vars.oaiapikey
|
js["oaiapikey"] = vars.oaiapikey
|
||||||
# Write API key to file
|
# Write API key to file
|
||||||
file = open("client.settings", "w")
|
file = open("settings/" + getmodelname() + ".settings", "w")
|
||||||
try:
|
try:
|
||||||
file.write(json.dumps(js, indent=3))
|
file.write(json.dumps(js, indent=3))
|
||||||
finally:
|
finally:
|
||||||
@ -818,10 +833,11 @@ def savesettings():
|
|||||||
js["widepth"] = vars.widepth
|
js["widepth"] = vars.widepth
|
||||||
js["useprompt"] = vars.useprompt
|
js["useprompt"] = vars.useprompt
|
||||||
js["adventure"] = vars.adventure
|
js["adventure"] = vars.adventure
|
||||||
js["msoverride"] = vars.msoverride
|
|
||||||
|
|
||||||
# Write it
|
# Write it
|
||||||
file = open("client.settings", "w")
|
if not os.path.exists('settings'):
|
||||||
|
os.mkdir('settings')
|
||||||
|
file = open("settings/" + getmodelname() + ".settings", "w")
|
||||||
try:
|
try:
|
||||||
file.write(json.dumps(js, indent=3))
|
file.write(json.dumps(js, indent=3))
|
||||||
finally:
|
finally:
|
||||||
@ -831,9 +847,9 @@ def savesettings():
|
|||||||
# Read settings from client file JSON and send to vars
|
# Read settings from client file JSON and send to vars
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
def loadsettings():
|
def loadsettings():
|
||||||
if(path.exists("client.settings")):
|
if(path.exists("settings/" + getmodelname() + ".settings")):
|
||||||
# Read file contents into JSON object
|
# Read file contents into JSON object
|
||||||
file = open("client.settings", "r")
|
file = open("settings/" + getmodelname() + ".settings", "r")
|
||||||
js = json.load(file)
|
js = json.load(file)
|
||||||
|
|
||||||
# Copy file contents to vars
|
# Copy file contents to vars
|
||||||
@ -867,8 +883,6 @@ def loadsettings():
|
|||||||
vars.useprompt = js["useprompt"]
|
vars.useprompt = js["useprompt"]
|
||||||
if("adventure" in js):
|
if("adventure" in js):
|
||||||
vars.adventure = js["adventure"]
|
vars.adventure = js["adventure"]
|
||||||
if("msoverride" in js):
|
|
||||||
vars.msoverride = js["msoverride"]
|
|
||||||
|
|
||||||
file.close()
|
file.close()
|
||||||
|
|
||||||
@ -881,24 +895,17 @@ def loadmodelsettings():
|
|||||||
js = json.load(model_config)
|
js = json.load(model_config)
|
||||||
if("badwordsids" in js):
|
if("badwordsids" in js):
|
||||||
vars.badwordsids = js["badwordsids"]
|
vars.badwordsids = js["badwordsids"]
|
||||||
if vars.msoverride:
|
|
||||||
if("temp" in js):
|
if("temp" in js):
|
||||||
print("temp forced by model")
|
|
||||||
vars.temp = js["temp"]
|
vars.temp = js["temp"]
|
||||||
if("top_p" in js):
|
if("top_p" in js):
|
||||||
print("top_p forced by model")
|
|
||||||
vars.top_p = js["top_p"]
|
vars.top_p = js["top_p"]
|
||||||
if("top_k" in js):
|
if("top_k" in js):
|
||||||
print("top_k forced by model")
|
|
||||||
vars.top_k = js["top_k"]
|
vars.top_k = js["top_k"]
|
||||||
if("tfs" in js):
|
if("tfs" in js):
|
||||||
print("tfs forced by model")
|
|
||||||
vars.tfs = js["tfs"]
|
vars.tfs = js["tfs"]
|
||||||
if("rep_pen" in js):
|
if("rep_pen" in js):
|
||||||
print("Repetition Penalty forced by model")
|
|
||||||
vars.rep_pen = js["rep_pen"]
|
vars.rep_pen = js["rep_pen"]
|
||||||
if("adventure" in js):
|
if("adventure" in js):
|
||||||
print("Adventure mode enabled/disabled by model")
|
|
||||||
vars.adventure = js["adventure"]
|
vars.adventure = js["adventure"]
|
||||||
model_config.close()
|
model_config.close()
|
||||||
|
|
||||||
@ -2337,8 +2344,9 @@ def randomGameRequest(topic):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# Load settings from client.settings
|
# Load settings from client.settings
|
||||||
loadsettings()
|
|
||||||
loadmodelsettings()
|
loadmodelsettings()
|
||||||
|
loadsettings()
|
||||||
|
|
||||||
# Start Flask/SocketIO (Blocking, so this must be last method!)
|
# Start Flask/SocketIO (Blocking, so this must be last method!)
|
||||||
|
|
||||||
#socketio.run(app, host='0.0.0.0', port=5000)
|
#socketio.run(app, host='0.0.0.0', port=5000)
|
||||||
|
Reference in New Issue
Block a user