Merge pull request #245 from db0/kaimergemain2
Makes prod version of KAI work with merged hordes in stablehorde.net
This commit is contained in:
commit
750cc3d2dc
92
aiserver.py
92
aiserver.py
|
@ -1511,7 +1511,7 @@ def get_model_info(model, directory=""):
|
|||
models_on_url = True
|
||||
url = True
|
||||
key = True
|
||||
default_url = 'https://koboldai.net'
|
||||
default_url = 'https://stablehorde.net'
|
||||
multi_online_models = True
|
||||
if path.exists(get_config_filename(model)):
|
||||
with open(get_config_filename(model), "r") as file:
|
||||
|
@ -1671,7 +1671,7 @@ def get_cluster_models(msg):
|
|||
# Get list of models from public cluster
|
||||
logger.init("KAI Horde Models", status="Retrieving")
|
||||
try:
|
||||
req = requests.get("{}/api/v1/models".format(url))
|
||||
req = requests.get(f"{url}/api/v2/status/models?type=text")
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.init_err("KAI Horde Models", status="Failed")
|
||||
logger.error("Provided KoboldAI Horde URL unreachable")
|
||||
|
@ -1687,10 +1687,11 @@ def get_cluster_models(msg):
|
|||
engines = req.json()
|
||||
logger.debug(engines)
|
||||
try:
|
||||
engines = [[en, en] for en in engines]
|
||||
engines = [[en["name"], en["name"]] for en in engines]
|
||||
except:
|
||||
logger.error(engines)
|
||||
raise
|
||||
logger.debug(engines)
|
||||
|
||||
online_model = ""
|
||||
changed=False
|
||||
|
@ -5269,15 +5270,21 @@ def sendtocluster(txt, min, max):
|
|||
cluster_metadata = {
|
||||
'prompt': txt,
|
||||
'params': reqdata,
|
||||
'api_key': vars.apikey,
|
||||
'models': vars.cluster_requested_models,
|
||||
}
|
||||
'trusted_workers': False,
|
||||
}
|
||||
client_agent = "KoboldAI:1.19.3:(discord)Henky!!#2205"
|
||||
cluster_headers = {
|
||||
'apikey': vars.apikey,
|
||||
"Client-Agent": client_agent
|
||||
}
|
||||
logger.debug(f"Horde Payload: {cluster_metadata}")
|
||||
try:
|
||||
# Create request
|
||||
req = requests.post(
|
||||
vars.colaburl[:-8] + "/api/v1/generate/sync",
|
||||
vars.colaburl[:-8] + "/api/v2/generate/text/async",
|
||||
json=cluster_metadata,
|
||||
headers=cluster_headers,
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
errmsg = f"Horde unavailable. Please try again later"
|
||||
|
@ -5305,13 +5312,76 @@ def sendtocluster(txt, min, max):
|
|||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||
set_aibusy(0)
|
||||
return
|
||||
gen_servers = [(cgen['server_name'],cgen['server_id']) for cgen in js]
|
||||
logger.info(f"Generations by: {gen_servers}")
|
||||
|
||||
request_id = js["id"]
|
||||
logger.debug("Horde Request ID: {}".format(request_id))
|
||||
|
||||
cluster_agent_headers = {
|
||||
"Client-Agent": client_agent
|
||||
}
|
||||
finished = False
|
||||
|
||||
while not finished:
|
||||
try:
|
||||
req = requests.get(vars.colaburl[:-8] + "/api/v2/generate/text/status/" + request_id, headers=cluster_agent_headers)
|
||||
except requests.exceptions.ConnectionError:
|
||||
errmsg = f"Horde unavailable. Please try again later"
|
||||
logger.error(errmsg)
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||
set_aibusy(0)
|
||||
return
|
||||
|
||||
if not req.ok:
|
||||
errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
|
||||
logger.error(req.text)
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||
set_aibusy(0)
|
||||
return
|
||||
|
||||
try:
|
||||
req_status = req.json()
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
errmsg = f"Unexpected message received from the KoboldAI Horde: '{req.text}'"
|
||||
logger.error(errmsg)
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||
set_aibusy(0)
|
||||
return
|
||||
|
||||
if "done" not in req_status:
|
||||
errmsg = f"Unexpected response received from the KoboldAI Horde: '{js}'"
|
||||
logger.error(errmsg)
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||
set_aibusy(0)
|
||||
return
|
||||
|
||||
finished = req_status["done"]
|
||||
|
||||
if not finished:
|
||||
logger.debug(req_status)
|
||||
time.sleep(1)
|
||||
|
||||
logger.debug("Last Horde Status Message: {}".format(js))
|
||||
if req_status["faulted"]:
|
||||
errmsg = "Horde Text generation faulted! Please try again"
|
||||
logger.error(errmsg)
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||
set_aibusy(0)
|
||||
return
|
||||
|
||||
generations = req_status['generations']
|
||||
gen_workers = [(cgen['worker_name'],cgen['worker_id']) for cgen in generations]
|
||||
logger.info(f"Generations by: {gen_workers}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Just in case we want to announce it to the user
|
||||
if len(js) == 1:
|
||||
warnmsg = f"Text generated by {js[0]['server_name']}"
|
||||
if len(generations) == 1:
|
||||
warnmsg = f"Text generated by {[w[0] for w in gen_workers]}"
|
||||
emit('from_server', {'cmd': 'warnmsg', 'data': warnmsg}, broadcast=True)
|
||||
genout = [cgen['text'] for cgen in js]
|
||||
genout = [cgen['text'] for cgen in generations]
|
||||
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.outputs[i+1] = genout[i]
|
||||
|
|
Loading…
Reference in New Issue