fix previously saved settings overwriting new API key

This commit is contained in:
Divided by Zer0 2022-09-16 16:23:04 +02:00
parent 981acaef71
commit 4362ca4b34
1 changed files with 75 additions and 65 deletions

View File

@ -1108,38 +1108,41 @@ def loadsettings():
def processsettings(js): def processsettings(js):
# Copy file contents to vars # Copy file contents to vars
if("apikey" in js): if("apikey" in js):
vars.apikey = js["apikey"] # If the model is the HORDE, then previously saved API key in settings
# Will always override a new key set.
if vars.model != "CLUSTER" or vars.apikey == '':
vars.apikey = js["apikey"]
if("andepth" in js): if("andepth" in js):
vars.andepth = js["andepth"] vars.andepth = js["andepth"]
if("sampler_order" in js): if("sampler_order" in js):
sampler_order = vars.sampler_order sampler_order = vars.sampler_order
if(len(sampler_order) < 7): if(len(sampler_order) < 7):
sampler_order = [6] + sampler_order sampler_order = [6] + sampler_order
vars.sampler_order = sampler_order vars.sampler_order = sampler_order
if("temp" in js): if("temp" in js):
vars.temp = js["temp"] vars.temp = js["temp"]
if("top_p" in js): if("top_p" in js):
vars.top_p = js["top_p"] vars.top_p = js["top_p"]
if("top_k" in js): if("top_k" in js):
vars.top_k = js["top_k"] vars.top_k = js["top_k"]
if("tfs" in js): if("tfs" in js):
vars.tfs = js["tfs"] vars.tfs = js["tfs"]
if("typical" in js): if("typical" in js):
vars.typical = js["typical"] vars.typical = js["typical"]
if("top_a" in js): if("top_a" in js):
vars.top_a = js["top_a"] vars.top_a = js["top_a"]
if("rep_pen" in js): if("rep_pen" in js):
vars.rep_pen = js["rep_pen"] vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js): if("rep_pen_slope" in js):
vars.rep_pen_slope = js["rep_pen_slope"] vars.rep_pen_slope = js["rep_pen_slope"]
if("rep_pen_range" in js): if("rep_pen_range" in js):
vars.rep_pen_range = js["rep_pen_range"] vars.rep_pen_range = js["rep_pen_range"]
if("genamt" in js): if("genamt" in js):
vars.genamt = js["genamt"] vars.genamt = js["genamt"]
if("max_length" in js): if("max_length" in js):
vars.max_length = js["max_length"] vars.max_length = js["max_length"]
if("ikgen" in js): if("ikgen" in js):
vars.ikgen = js["ikgen"] vars.ikgen = js["ikgen"]
if("formatoptns" in js): if("formatoptns" in js):
vars.formatoptns = js["formatoptns"] vars.formatoptns = js["formatoptns"]
if("numseqs" in js): if("numseqs" in js):
@ -1608,51 +1611,57 @@ def get_cluster_models(msg):
vars.oaiapikey = msg['key'] vars.oaiapikey = msg['key']
vars.apikey = vars.oaiapikey vars.apikey = vars.oaiapikey
url = msg['url'] url = msg['url']
# Get list of models from public cluster # Get list of models from public cluster
logger.init("KAI Horde Models", status="Retrieving") logger.init("KAI Horde Models", status="Retrieving")
req = requests.get("{}/models".format(url)) try:
if(req.status_code == 200): req = requests.get("{}/models".format(url))
engines = req.json() except requests.exceptions.ConnectionError:
logger.debug(engines) logger.init_err("KAI Horde Models", status="Failed")
try: logger.error("Provided KoboldAI Horde URL unreachable")
engines = [[en, en] for en in engines] emit('from_server', {'cmd': 'errmsg', 'data': "Provided KoboldAI Horde URL unreachable"})
except: return
logger.error(engines) if(not req.ok):
raise
online_model = ""
changed=False
#Save the key
if not path.exists("settings"):
# If the client settings file doesn't exist, create it
# Write API key to file
os.makedirs('settings', exist_ok=True)
if path.exists(get_config_filename(vars.model_selected)):
with open(get_config_filename(vars.model_selected), "r") as file:
js = json.load(file)
if 'online_model' in js:
online_model = js['online_model']
if "apikey" in js:
if js['apikey'] != vars.oaiapikey:
changed=True
else:
changed=True
if changed:
js={}
with open(get_config_filename(vars.model_selected), "w") as file:
js["apikey"] = vars.oaiapikey
file.write(json.dumps(js, indent=3))
logger.init_ok("KAI Horde Models", status="OK")
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True)
else:
# Something went wrong, print the message and quit since we can't initialize an engine # Something went wrong, print the message and quit since we can't initialize an engine
logger.init_err("KAI Horde Models", status="Failed") logger.init_err("KAI Horde Models", status="Failed")
logger.error(req.json()) logger.error(req.json())
emit('from_server', {'cmd': 'errmsg', 'data': req.json()}) emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
return
engines = req.json()
logger.debug(engines)
try:
engines = [[en, en] for en in engines]
except:
logger.error(engines)
raise
online_model = ""
changed=False
#Save the key
if not path.exists("settings"):
# If the client settings file doesn't exist, create it
# Write API key to file
os.makedirs('settings', exist_ok=True)
if path.exists(get_config_filename(vars.model_selected)):
with open(get_config_filename(vars.model_selected), "r") as file:
js = json.load(file)
if 'online_model' in js:
online_model = js['online_model']
if "apikey" in js:
if js['apikey'] != vars.oaiapikey:
changed=True
else:
changed=True
if changed:
js={}
with open(get_config_filename(vars.model_selected), "w") as file:
js["apikey"] = vars.oaiapikey
file.write(json.dumps(js, indent=3))
logger.init_ok("KAI Horde Models", status="OK")
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True)
# Function to patch transformers to use our soft prompt # Function to patch transformers to use our soft prompt
def patch_causallm(model): def patch_causallm(model):
@ -3791,7 +3800,7 @@ def get_message(msg):
elif(msg['cmd'] == 'list_model'): elif(msg['cmd'] == 'list_model'):
sendModelSelection(menu=msg['data']) sendModelSelection(menu=msg['data'])
elif(msg['cmd'] == 'load_model'): elif(msg['cmd'] == 'load_model'):
logger.debug(vars.model_selected) logger.debug(f"Selected Model: {vars.model_selected}")
if not os.path.exists("settings/"): if not os.path.exists("settings/"):
os.mkdir("settings") os.mkdir("settings")
changed = True changed = True
@ -5159,7 +5168,6 @@ def sendtocluster(txt, min, max):
# Store context in memory to use it for comparison with generated content # Store context in memory to use it for comparison with generated content
vars.lastctx = txt vars.lastctx = txt
# Build request JSON data # Build request JSON data
reqdata = { reqdata = {
'max_length': max - min + 1, 'max_length': max - min + 1,
@ -5181,37 +5189,39 @@ def sendtocluster(txt, min, max):
'api_key': vars.apikey, 'api_key': vars.apikey,
'models': vars.cluster_requested_models, 'models': vars.cluster_requested_models,
} }
logger.debug(f"Horde Payload: {cluster_metadata}")
try: try:
# Create request # Create request
req = requests.post( req = requests.post(
vars.colaburl[:-8] + "/api/v1/generate/sync", vars.colaburl[:-8] + "/api/v1/generate/sync",
json=cluster_metadata, json=cluster_metadata,
) )
js = req.json()
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
errmsg = f"Horde unavailable. Please try again later" errmsg = f"Horde unavailable. Please try again later"
logger.error(errmsg) logger.error(errmsg)
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True) emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
set_aibusy(0) set_aibusy(0)
return return
if(req.status_code == 503):
errmsg = f"KoboldAI API Error: No available KoboldAI servers found in Horde to fulfil this request using the selected models or other properties."
logger.error(req.text)
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
set_aibusy(0)
return
if(not req.ok):
errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
logger.error(req.text)
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
set_aibusy(0)
return
try:
js = req.json()
except requests.exceptions.JSONDecodeError: except requests.exceptions.JSONDecodeError:
errmsg = f"Unexpected message received from the Horde: '{req.text}'" errmsg = f"Unexpected message received from the Horde: '{req.text}'"
logger.error(errmsg) logger.error(errmsg)
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True) emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
set_aibusy(0) set_aibusy(0)
return return
if(req.status_code == 503):
errmsg = f"KoboldAI API Error: No available KoboldAI servers found in Horde to fulfil this request using the selected models or other properties."
logger.error(json.dumps(js))
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
set_aibusy(0)
return
if(req.status_code != 200):
errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
logger.error(json.dumps(js))
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
set_aibusy(0)
return
gen_servers = [(cgen['server_name'],cgen['server_id']) for cgen in js] gen_servers = [(cgen['server_name'],cgen['server_id']) for cgen in js]
logger.info(f"Generations by: {gen_servers}") logger.info(f"Generations by: {gen_servers}")
# Just in case we want to announce it to the user # Just in case we want to announce it to the user