fix previously saved settings overwriting new API key
This commit is contained in:
parent
981acaef71
commit
4362ca4b34
140
aiserver.py
140
aiserver.py
|
@ -1108,38 +1108,41 @@ def loadsettings():
|
||||||
def processsettings(js):
|
def processsettings(js):
|
||||||
# Copy file contents to vars
|
# Copy file contents to vars
|
||||||
if("apikey" in js):
|
if("apikey" in js):
|
||||||
vars.apikey = js["apikey"]
|
# If the model is the HORDE, then previously saved API key in settings
|
||||||
|
# Will always override a new key set.
|
||||||
|
if vars.model != "CLUSTER" or vars.apikey == '':
|
||||||
|
vars.apikey = js["apikey"]
|
||||||
if("andepth" in js):
|
if("andepth" in js):
|
||||||
vars.andepth = js["andepth"]
|
vars.andepth = js["andepth"]
|
||||||
if("sampler_order" in js):
|
if("sampler_order" in js):
|
||||||
sampler_order = vars.sampler_order
|
sampler_order = vars.sampler_order
|
||||||
if(len(sampler_order) < 7):
|
if(len(sampler_order) < 7):
|
||||||
sampler_order = [6] + sampler_order
|
sampler_order = [6] + sampler_order
|
||||||
vars.sampler_order = sampler_order
|
vars.sampler_order = sampler_order
|
||||||
if("temp" in js):
|
if("temp" in js):
|
||||||
vars.temp = js["temp"]
|
vars.temp = js["temp"]
|
||||||
if("top_p" in js):
|
if("top_p" in js):
|
||||||
vars.top_p = js["top_p"]
|
vars.top_p = js["top_p"]
|
||||||
if("top_k" in js):
|
if("top_k" in js):
|
||||||
vars.top_k = js["top_k"]
|
vars.top_k = js["top_k"]
|
||||||
if("tfs" in js):
|
if("tfs" in js):
|
||||||
vars.tfs = js["tfs"]
|
vars.tfs = js["tfs"]
|
||||||
if("typical" in js):
|
if("typical" in js):
|
||||||
vars.typical = js["typical"]
|
vars.typical = js["typical"]
|
||||||
if("top_a" in js):
|
if("top_a" in js):
|
||||||
vars.top_a = js["top_a"]
|
vars.top_a = js["top_a"]
|
||||||
if("rep_pen" in js):
|
if("rep_pen" in js):
|
||||||
vars.rep_pen = js["rep_pen"]
|
vars.rep_pen = js["rep_pen"]
|
||||||
if("rep_pen_slope" in js):
|
if("rep_pen_slope" in js):
|
||||||
vars.rep_pen_slope = js["rep_pen_slope"]
|
vars.rep_pen_slope = js["rep_pen_slope"]
|
||||||
if("rep_pen_range" in js):
|
if("rep_pen_range" in js):
|
||||||
vars.rep_pen_range = js["rep_pen_range"]
|
vars.rep_pen_range = js["rep_pen_range"]
|
||||||
if("genamt" in js):
|
if("genamt" in js):
|
||||||
vars.genamt = js["genamt"]
|
vars.genamt = js["genamt"]
|
||||||
if("max_length" in js):
|
if("max_length" in js):
|
||||||
vars.max_length = js["max_length"]
|
vars.max_length = js["max_length"]
|
||||||
if("ikgen" in js):
|
if("ikgen" in js):
|
||||||
vars.ikgen = js["ikgen"]
|
vars.ikgen = js["ikgen"]
|
||||||
if("formatoptns" in js):
|
if("formatoptns" in js):
|
||||||
vars.formatoptns = js["formatoptns"]
|
vars.formatoptns = js["formatoptns"]
|
||||||
if("numseqs" in js):
|
if("numseqs" in js):
|
||||||
|
@ -1608,51 +1611,57 @@ def get_cluster_models(msg):
|
||||||
vars.oaiapikey = msg['key']
|
vars.oaiapikey = msg['key']
|
||||||
vars.apikey = vars.oaiapikey
|
vars.apikey = vars.oaiapikey
|
||||||
url = msg['url']
|
url = msg['url']
|
||||||
|
|
||||||
|
|
||||||
# Get list of models from public cluster
|
# Get list of models from public cluster
|
||||||
logger.init("KAI Horde Models", status="Retrieving")
|
logger.init("KAI Horde Models", status="Retrieving")
|
||||||
req = requests.get("{}/models".format(url))
|
try:
|
||||||
if(req.status_code == 200):
|
req = requests.get("{}/models".format(url))
|
||||||
engines = req.json()
|
except requests.exceptions.ConnectionError:
|
||||||
logger.debug(engines)
|
logger.init_err("KAI Horde Models", status="Failed")
|
||||||
try:
|
logger.error("Provided KoboldAI Horde URL unreachable")
|
||||||
engines = [[en, en] for en in engines]
|
emit('from_server', {'cmd': 'errmsg', 'data': "Provided KoboldAI Horde URL unreachable"})
|
||||||
except:
|
return
|
||||||
logger.error(engines)
|
if(not req.ok):
|
||||||
raise
|
|
||||||
|
|
||||||
online_model = ""
|
|
||||||
changed=False
|
|
||||||
|
|
||||||
#Save the key
|
|
||||||
if not path.exists("settings"):
|
|
||||||
# If the client settings file doesn't exist, create it
|
|
||||||
# Write API key to file
|
|
||||||
os.makedirs('settings', exist_ok=True)
|
|
||||||
if path.exists(get_config_filename(vars.model_selected)):
|
|
||||||
with open(get_config_filename(vars.model_selected), "r") as file:
|
|
||||||
js = json.load(file)
|
|
||||||
if 'online_model' in js:
|
|
||||||
online_model = js['online_model']
|
|
||||||
if "apikey" in js:
|
|
||||||
if js['apikey'] != vars.oaiapikey:
|
|
||||||
changed=True
|
|
||||||
else:
|
|
||||||
changed=True
|
|
||||||
if changed:
|
|
||||||
js={}
|
|
||||||
with open(get_config_filename(vars.model_selected), "w") as file:
|
|
||||||
js["apikey"] = vars.oaiapikey
|
|
||||||
file.write(json.dumps(js, indent=3))
|
|
||||||
|
|
||||||
logger.init_ok("KAI Horde Models", status="OK")
|
|
||||||
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True)
|
|
||||||
else:
|
|
||||||
# Something went wrong, print the message and quit since we can't initialize an engine
|
# Something went wrong, print the message and quit since we can't initialize an engine
|
||||||
logger.init_err("KAI Horde Models", status="Failed")
|
logger.init_err("KAI Horde Models", status="Failed")
|
||||||
logger.error(req.json())
|
logger.error(req.json())
|
||||||
emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
|
emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
|
||||||
|
return
|
||||||
|
|
||||||
|
engines = req.json()
|
||||||
|
logger.debug(engines)
|
||||||
|
try:
|
||||||
|
engines = [[en, en] for en in engines]
|
||||||
|
except:
|
||||||
|
logger.error(engines)
|
||||||
|
raise
|
||||||
|
|
||||||
|
online_model = ""
|
||||||
|
changed=False
|
||||||
|
|
||||||
|
#Save the key
|
||||||
|
if not path.exists("settings"):
|
||||||
|
# If the client settings file doesn't exist, create it
|
||||||
|
# Write API key to file
|
||||||
|
os.makedirs('settings', exist_ok=True)
|
||||||
|
if path.exists(get_config_filename(vars.model_selected)):
|
||||||
|
with open(get_config_filename(vars.model_selected), "r") as file:
|
||||||
|
js = json.load(file)
|
||||||
|
if 'online_model' in js:
|
||||||
|
online_model = js['online_model']
|
||||||
|
if "apikey" in js:
|
||||||
|
if js['apikey'] != vars.oaiapikey:
|
||||||
|
changed=True
|
||||||
|
else:
|
||||||
|
changed=True
|
||||||
|
if changed:
|
||||||
|
js={}
|
||||||
|
with open(get_config_filename(vars.model_selected), "w") as file:
|
||||||
|
js["apikey"] = vars.oaiapikey
|
||||||
|
file.write(json.dumps(js, indent=3))
|
||||||
|
|
||||||
|
logger.init_ok("KAI Horde Models", status="OK")
|
||||||
|
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True)
|
||||||
|
|
||||||
|
|
||||||
# Function to patch transformers to use our soft prompt
|
# Function to patch transformers to use our soft prompt
|
||||||
def patch_causallm(model):
|
def patch_causallm(model):
|
||||||
|
@ -3791,7 +3800,7 @@ def get_message(msg):
|
||||||
elif(msg['cmd'] == 'list_model'):
|
elif(msg['cmd'] == 'list_model'):
|
||||||
sendModelSelection(menu=msg['data'])
|
sendModelSelection(menu=msg['data'])
|
||||||
elif(msg['cmd'] == 'load_model'):
|
elif(msg['cmd'] == 'load_model'):
|
||||||
logger.debug(vars.model_selected)
|
logger.debug(f"Selected Model: {vars.model_selected}")
|
||||||
if not os.path.exists("settings/"):
|
if not os.path.exists("settings/"):
|
||||||
os.mkdir("settings")
|
os.mkdir("settings")
|
||||||
changed = True
|
changed = True
|
||||||
|
@ -5159,7 +5168,6 @@ def sendtocluster(txt, min, max):
|
||||||
|
|
||||||
# Store context in memory to use it for comparison with generated content
|
# Store context in memory to use it for comparison with generated content
|
||||||
vars.lastctx = txt
|
vars.lastctx = txt
|
||||||
|
|
||||||
# Build request JSON data
|
# Build request JSON data
|
||||||
reqdata = {
|
reqdata = {
|
||||||
'max_length': max - min + 1,
|
'max_length': max - min + 1,
|
||||||
|
@ -5181,37 +5189,39 @@ def sendtocluster(txt, min, max):
|
||||||
'api_key': vars.apikey,
|
'api_key': vars.apikey,
|
||||||
'models': vars.cluster_requested_models,
|
'models': vars.cluster_requested_models,
|
||||||
}
|
}
|
||||||
|
logger.debug(f"Horde Payload: {cluster_metadata}")
|
||||||
try:
|
try:
|
||||||
# Create request
|
# Create request
|
||||||
req = requests.post(
|
req = requests.post(
|
||||||
vars.colaburl[:-8] + "/api/v1/generate/sync",
|
vars.colaburl[:-8] + "/api/v1/generate/sync",
|
||||||
json=cluster_metadata,
|
json=cluster_metadata,
|
||||||
)
|
)
|
||||||
js = req.json()
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
errmsg = f"Horde unavailable. Please try again later"
|
errmsg = f"Horde unavailable. Please try again later"
|
||||||
logger.error(errmsg)
|
logger.error(errmsg)
|
||||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
return
|
return
|
||||||
|
if(req.status_code == 503):
|
||||||
|
errmsg = f"KoboldAI API Error: No available KoboldAI servers found in Horde to fulfil this request using the selected models or other properties."
|
||||||
|
logger.error(req.text)
|
||||||
|
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||||
|
set_aibusy(0)
|
||||||
|
return
|
||||||
|
if(not req.ok):
|
||||||
|
errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
|
||||||
|
logger.error(req.text)
|
||||||
|
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||||
|
set_aibusy(0)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
js = req.json()
|
||||||
except requests.exceptions.JSONDecodeError:
|
except requests.exceptions.JSONDecodeError:
|
||||||
errmsg = f"Unexpected message received from the Horde: '{req.text}'"
|
errmsg = f"Unexpected message received from the Horde: '{req.text}'"
|
||||||
logger.error(errmsg)
|
logger.error(errmsg)
|
||||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
return
|
return
|
||||||
if(req.status_code == 503):
|
|
||||||
errmsg = f"KoboldAI API Error: No available KoboldAI servers found in Horde to fulfil this request using the selected models or other properties."
|
|
||||||
logger.error(json.dumps(js))
|
|
||||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
|
||||||
set_aibusy(0)
|
|
||||||
return
|
|
||||||
if(req.status_code != 200):
|
|
||||||
errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
|
|
||||||
logger.error(json.dumps(js))
|
|
||||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
|
||||||
set_aibusy(0)
|
|
||||||
return
|
|
||||||
gen_servers = [(cgen['server_name'],cgen['server_id']) for cgen in js]
|
gen_servers = [(cgen['server_name'],cgen['server_id']) for cgen in js]
|
||||||
logger.info(f"Generations by: {gen_servers}")
|
logger.info(f"Generations by: {gen_servers}")
|
||||||
# Just in case we want to announce it to the user
|
# Just in case we want to announce it to the user
|
||||||
|
|
Loading…
Reference in New Issue