diff --git a/aiserver.py b/aiserver.py
index 8925b46d..81bdcf24 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1876,49 +1876,13 @@ def get_cluster_models(msg):
     # Get list of models from public cluster
     print("{0}Retrieving engine list...{1}".format(colors.PURPLE, colors.END), end="")
     try:
-        req = requests.get("{}/api/v1/models".format(url))
+        req = requests.get(f"{url}/api/v2/status/models?type=text")
     except:
         logger.init_err("KAI Horde Models", status="Failed")
         logger.error("Provided KoboldAI Horde URL unreachable")
         emit('from_server', {'cmd': 'errmsg', 'data': "Provided KoboldAI Horde URL unreachable"})
         return
-    if(req.status_code == 200):
-        engines = req.json()
-        print(engines)
-        try:
-            engines = [[en, en] for en in engines]
-        except:
-            print(engines)
-            raise
-        print(engines)
-
-        online_model = ""
-        changed=False
-
-        #Save the key
-        if not path.exists("settings"):
-            # If the client settings file doesn't exist, create it
-            # Write API key to file
-            os.makedirs('settings', exist_ok=True)
-        if path.exists(get_config_filename(model)):
-            with open(get_config_filename(model), "r") as file:
-                js = json.load(file)
-                if 'online_model' in js:
-                    online_model = js['online_model']
-                if "apikey" in js:
-                    if js['apikey'] != koboldai_vars.oaiapikey:
-                        changed=True
-                else:
-                    changed=True
-        if changed:
-            js={}
-            with open(get_config_filename(model), "w") as file:
-                js["apikey"] = koboldai_vars.oaiapikey
-                file.write(json.dumps(js, indent=3))
-
-        emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True, room="UI_1")
-        emit('oai_engines', {'data': engines, 'online_model': online_model}, broadcast=False, room="UI_2")
-    else:
+    if not req.ok:
         # Something went wrong, print the message and quit since we can't initialize an engine
         logger.init_err("KAI Horde Models", status="Failed")
         logger.error(req.json())
@@ -1928,10 +1892,11 @@ def get_cluster_models(msg):
     engines = req.json()
     logger.debug(engines)
     try:
-        engines = [[en, en] for en in engines]
+        engines = [[en["name"], en["name"]] for en in engines]
     except:
         logger.error(engines)
         raise
+    logger.debug(engines)
 
     online_model = ""
     changed=False
@@ -1956,10 +1921,12 @@ def get_cluster_models(msg):
         with open(get_config_filename(model), "w") as file:
             js["apikey"] = koboldai_vars.oaiapikey
             js["url"] = url
-            file.write(json.dumps(js, indent=3))
-
+            file.write(json.dumps(js, indent=3))
+
     logger.init_ok("KAI Horde Models", status="OK")
-    emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True)
+
+    emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True, room="UI_1")
+    emit('oai_engines', {'data': engines, 'online_model': online_model}, broadcast=False, room="UI_2")
 
 
 # Function to patch transformers to use our soft prompt
@@ -5965,12 +5932,15 @@ def cluster_raw_generate(
         'trusted_workers': False,
     }
 
-    cluster_headers = {'apikey': koboldai_vars.apikey}
-
+    client_agent = "KoboldAI:2.0.0:(discord)ebolam#1007"
+    cluster_headers = {
+        'apikey': koboldai_vars.apikey,
+        "Client-Agent": client_agent
+    }
     try:
         # Create request
         req = requests.post(
-            koboldai_vars.colaburl[:-8] + "/api/v2/generate/async",
+            koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/async",
             json=cluster_metadata,
             headers=cluster_headers
         )
@@ -6003,9 +5973,13 @@ def cluster_raw_generate(
     # We've sent the request and got the ID back, now we need to watch it to see when it finishes
     finished = False
+    cluster_agent_headers = {
+        "Client-Agent": client_agent
+    }
+
 
     while not finished:
-        try:
-            req = requests.get(koboldai_vars.colaburl[:-8] + "/api/v1/generate/check/" + request_id)
+        try:
+            req = requests.get(koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/status/" + request_id, headers=cluster_agent_headers)
         except requests.exceptions.ConnectionError:
             errmsg = f"Horde unavailable. Please try again later"
             logger.error(errmsg)
@@ -6017,32 +5991,33 @@ def cluster_raw_generate(
             raise HordeException(errmsg)
 
         try:
-            js = req.json()
+            req_status = req.json()
         except requests.exceptions.JSONDecodeError:
             errmsg = f"Unexpected message received from the KoboldAI Horde: '{req.text}'"
             logger.error(errmsg)
             raise HordeException(errmsg)
 
-        if "done" not in js:
+        if "done" not in req_status:
             errmsg = f"Unexpected response received from the KoboldAI Horde: '{js}'"
-            logger.error(errmsg )
+            logger.error(errmsg)
             raise HordeException(errmsg)
 
-        finished = js["done"]
-        koboldai_vars.horde_wait_time = js["wait_time"]
-        koboldai_vars.horde_queue_position = js["queue_position"]
-        koboldai_vars.horde_queue_size = js["waiting"]
+        finished = req_status["done"]
+        koboldai_vars.horde_wait_time = req_status["wait_time"]
+        koboldai_vars.horde_queue_position = req_status["queue_position"]
+        koboldai_vars.horde_queue_size = req_status["waiting"]
 
         if not finished:
-            logger.debug(js)
+            logger.debug(req_status)
             time.sleep(1)
 
     logger.debug("Last Horde Status Message: {}".format(js))
-    js = requests.get(koboldai_vars.colaburl[:-8] + "/api/v1/generate/prompt/" + request_id).json()['generations']
-    logger.debug("Horde Result: {}".format(js))
+    if req_status["faulted"]:
+        raise HordeException("Horde Text generation faulted! Please try again")
 
-    gen_servers = [(cgen['server_name'],cgen['server_id']) for cgen in js]
-    logger.info(f"Generations by: {gen_servers}")
+    generations = req_status['generations']
+    gen_workers = [(cgen['worker_name'],cgen['worker_id']) for cgen in generations]
+    logger.info(f"Generations by: {gen_workers}")
 
     # TODO: Fix this, using tpool so it's a context error
     # Just in case we want to announce it to the user
@@ -6050,7 +6025,7 @@ def cluster_raw_generate(
     # warnmsg = f"Text generated by {js[0]['server_name']}"
     # emit('from_server', {'cmd': 'warnmsg', 'data': warnmsg}, broadcast=True)
 
-    return np.array([tokenizer.encode(cgen["text"]) for cgen in js])
+    return np.array([tokenizer.encode(cgen["text"]) for cgen in generations])
 
 def colab_raw_generate(
     prompt_tokens: List[int],
@@ -9846,8 +9821,11 @@ def text2img_horde(prompt: str) -> Optional[Image.Image]:
             "height": 512
         }
     }
-
-    cluster_headers = {"apikey": koboldai_vars.sh_apikey or "0000000000"}
+    client_agent = "KoboldAI:2.0.0:(discord)ebolam#1007"
+    cluster_headers = {
+        'apikey': koboldai_vars.apikey or "0000000000",
+        "Client-Agent": client_agent
+    }
 
     id_req = requests.post("https://stablehorde.net/api/v2/generate/async", json=final_submit_dict, headers=cluster_headers)
     if not id_req.ok:
@@ -9895,11 +9873,14 @@ def text2img_horde(prompt: str) -> Optional[Image.Image]:
     if len(results["generations"]) > 1:
         logger.warning(f"Got too many generations, discarding extras. Got {len(results['generations'])}, expected 1.")
 
-    b64img = results["generations"][0]["img"]
-    base64_bytes = b64img.encode("utf-8")
-    img_bytes = base64.b64decode(base64_bytes)
-    img = Image.open(BytesIO(img_bytes))
-    return img
+    imgurl = results["generations"][0]["img"]
+    try:
+        img_data = requests.get(imgurl, timeout=3).content
+        img = Image.open(BytesIO(img_data))
+        return img
+    except Exception as err:
+        logger.error(f"Error retrieving image: {err}")
+        raise HordeException("Image fetching failed. See console for more details.")
 
 @logger.catch
 def text2img_api(prompt, art_guide="") -> Image.Image: