From 2de9672b954c6d34f83c3fdc6f00476a222a6f31 Mon Sep 17 00:00:00 2001 From: Divided by Zer0 Date: Thu, 23 Feb 2023 18:27:11 +0100 Subject: [PATCH] attempt1 --- aiserver.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 80 insertions(+), 10 deletions(-) diff --git a/aiserver.py b/aiserver.py index 665b43f6..92769066 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1671,7 +1671,7 @@ def get_cluster_models(msg): # Get list of models from public cluster logger.init("KAI Horde Models", status="Retrieving") try: - req = requests.get("{}/api/v1/models".format(url)) + req = requests.get(f"{url}/api/v2/status/models?type=text") except requests.exceptions.ConnectionError: logger.init_err("KAI Horde Models", status="Failed") logger.error("Provided KoboldAI Horde URL unreachable") @@ -1687,10 +1687,11 @@ def get_cluster_models(msg): engines = req.json() logger.debug(engines) try: - engines = [[en, en] for en in engines] + engines = [[en["name"], en["name"]] for en in engines] except: logger.error(engines) raise + logger.debug(engines) online_model = "" changed=False @@ -5269,15 +5270,21 @@ def sendtocluster(txt, min, max): cluster_metadata = { 'prompt': txt, 'params': reqdata, - 'api_key': vars.apikey, 'models': vars.cluster_requested_models, - } + 'trusted_workers': False, + } + client_agent = "KoboldAI:1.19.3:(discord)Henky!!#2205" + cluster_headers = { + 'apikey': koboldai_vars.apikey, + "Client-Agent": client_agent + } logger.debug(f"Horde Payload: {cluster_metadata}") try: # Create request req = requests.post( - vars.colaburl[:-8] + "/api/v1/generate/sync", + vars.colaburl[:-8] + "/api/v2/generate/text/async", json=cluster_metadata, + headers=cluster_headers, ) except requests.exceptions.ConnectionError: errmsg = f"Horde unavailable. Please try again later" @@ -5305,13 +5312,76 @@ def sendtocluster(txt, min, max): emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True) set_aibusy(0) return - gen_servers = [(cgen['server_name'],cgen['server_id']) for cgen in js] - logger.info(f"Generations by: {gen_servers}") + + request_id = js["id"] + logger.debug("Horde Request ID: {}".format(request_id)) + + cluster_agent_headers = { + "Client-Agent": client_agent + } + finished = False + + while not finished: + try: + req = requests.get(koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/status/" + request_id, headers=cluster_agent_headers) + except requests.exceptions.ConnectionError: + errmsg = f"Horde unavailable. Please try again later" + logger.error(errmsg) + emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True) + set_aibusy(0) + return + + if not req.ok: + errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console." + logger.error(req.text) + emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True) + set_aibusy(0) + return + + try: + req_status = req.json() + except requests.exceptions.JSONDecodeError: + errmsg = f"Unexpected message received from the KoboldAI Horde: '{req.text}'" + logger.error(errmsg) + emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True) + set_aibusy(0) + return + + if "done" not in req_status: + errmsg = f"Unexpected response received from the KoboldAI Horde: '{js}'" + logger.error(errmsg) + emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True) + set_aibusy(0) + return + + finished = req_status["done"] + + if not finished: + logger.debug(req_status) + time.sleep(1) + + logger.debug("Last Horde Status Message: {}".format(js)) + if req_status["faulted"]: + errmsg = "Horde Text generation faulted! Please try again" + logger.error(errmsg) + emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True) + set_aibusy(0) + return + + generations = req_status['generations'] + gen_workers = [(cgen['worker_name'],cgen['worker_id']) for cgen in generations] + logger.info(f"Generations by: {gen_workers}") + + + + + + # Just in case we want to announce it to the user - if len(js) == 1: - warnmsg = f"Text generated by {js[0]['server_name']}" + if len(generations) == 1: + warnmsg = f"Text generated by {[w[0] for w in gen_workers]}" emit('from_server', {'cmd': 'warnmsg', 'data': warnmsg}, broadcast=True) - genout = [cgen['text'] for cgen in js] + genout = [cgen['text'] for cgen in generations] for i in range(vars.numseqs): vars.lua_koboldbridge.outputs[i+1] = genout[i]