Merge pull request #294 from db0/kaimerge

Changes to work with the merged hordes
This commit is contained in:
henk717
2023-02-26 16:00:48 +01:00
committed by GitHub

View File

@@ -1876,49 +1876,13 @@ def get_cluster_models(msg):
# Get list of models from public cluster
print("{0}Retrieving engine list...{1}".format(colors.PURPLE, colors.END), end="")
try:
req = requests.get("{}/api/v1/models".format(url))
req = requests.get(f"{url}/api/v2/status/models?type=text")
except:
logger.init_err("KAI Horde Models", status="Failed")
logger.error("Provided KoboldAI Horde URL unreachable")
emit('from_server', {'cmd': 'errmsg', 'data': "Provided KoboldAI Horde URL unreachable"})
return
if(req.status_code == 200):
engines = req.json()
print(engines)
try:
engines = [[en, en] for en in engines]
except:
print(engines)
raise
print(engines)
online_model = ""
changed=False
#Save the key
if not path.exists("settings"):
# If the client settings file doesn't exist, create it
# Write API key to file
os.makedirs('settings', exist_ok=True)
if path.exists(get_config_filename(model)):
with open(get_config_filename(model), "r") as file:
js = json.load(file)
if 'online_model' in js:
online_model = js['online_model']
if "apikey" in js:
if js['apikey'] != koboldai_vars.oaiapikey:
changed=True
else:
changed=True
if changed:
js={}
with open(get_config_filename(model), "w") as file:
js["apikey"] = koboldai_vars.oaiapikey
file.write(json.dumps(js, indent=3))
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True, room="UI_1")
emit('oai_engines', {'data': engines, 'online_model': online_model}, broadcast=False, room="UI_2")
else:
if not req.ok:
# Something went wrong, print the message and quit since we can't initialize an engine
logger.init_err("KAI Horde Models", status="Failed")
logger.error(req.json())
@@ -1928,10 +1892,11 @@ def get_cluster_models(msg):
engines = req.json()
logger.debug(engines)
try:
engines = [[en, en] for en in engines]
engines = [[en["name"], en["name"]] for en in engines]
except:
logger.error(engines)
raise
logger.debug(engines)
online_model = ""
changed=False
@@ -1956,10 +1921,12 @@ def get_cluster_models(msg):
with open(get_config_filename(model), "w") as file:
js["apikey"] = koboldai_vars.oaiapikey
js["url"] = url
file.write(json.dumps(js, indent=3))
file.write(json.dumps(js, indent=3))
logger.init_ok("KAI Horde Models", status="OK")
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True)
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True, room="UI_1")
emit('oai_engines', {'data': engines, 'online_model': online_model}, broadcast=False, room="UI_2")
# Function to patch transformers to use our soft prompt
@@ -5965,12 +5932,15 @@ def cluster_raw_generate(
'trusted_workers': False,
}
cluster_headers = {'apikey': koboldai_vars.apikey}
client_agent = "KoboldAI:2.0.0:(discord)ebolam#1007"
cluster_headers = {
'apikey': koboldai_vars.apikey,
"Client-Agent": client_agent
}
try:
# Create request
req = requests.post(
koboldai_vars.colaburl[:-8] + "/api/v2/generate/async",
koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/async",
json=cluster_metadata,
headers=cluster_headers
)
@@ -6003,9 +5973,13 @@ def cluster_raw_generate(
# We've sent the request and got the ID back, now we need to watch it to see when it finishes
finished = False
cluster_agent_headers = {
"Client-Agent": client_agent
}
while not finished:
try:
req = requests.get(koboldai_vars.colaburl[:-8] + "/api/v1/generate/check/" + request_id)
try:
req = requests.get(koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/status/" + request_id, headers=cluster_agent_headers)
except requests.exceptions.ConnectionError:
errmsg = f"Horde unavailable. Please try again later"
logger.error(errmsg)
@@ -6017,32 +5991,33 @@ def cluster_raw_generate(
raise HordeException(errmsg)
try:
js = req.json()
req_status = req.json()
except requests.exceptions.JSONDecodeError:
errmsg = f"Unexpected message received from the KoboldAI Horde: '{req.text}'"
logger.error(errmsg)
raise HordeException(errmsg)
if "done" not in js:
if "done" not in req_status:
errmsg = f"Unexpected response received from the KoboldAI Horde: '{js}'"
logger.error(errmsg )
logger.error(errmsg)
raise HordeException(errmsg)
finished = js["done"]
koboldai_vars.horde_wait_time = js["wait_time"]
koboldai_vars.horde_queue_position = js["queue_position"]
koboldai_vars.horde_queue_size = js["waiting"]
finished = req_status["done"]
koboldai_vars.horde_wait_time = req_status["wait_time"]
koboldai_vars.horde_queue_position = req_status["queue_position"]
koboldai_vars.horde_queue_size = req_status["waiting"]
if not finished:
logger.debug(js)
logger.debug(req_status)
time.sleep(1)
logger.debug("Last Horde Status Message: {}".format(js))
js = requests.get(koboldai_vars.colaburl[:-8] + "/api/v1/generate/prompt/" + request_id).json()['generations']
logger.debug("Horde Result: {}".format(js))
if req_status["faulted"]:
raise HordeException("Horde Text generation faulted! Please try again")
gen_servers = [(cgen['server_name'],cgen['server_id']) for cgen in js]
logger.info(f"Generations by: {gen_servers}")
generations = req_status['generations']
gen_workers = [(cgen['worker_name'],cgen['worker_id']) for cgen in generations]
logger.info(f"Generations by: {gen_workers}")
# TODO: Fix this, using tpool so it's a context error
# Just in case we want to announce it to the user
@@ -6050,7 +6025,7 @@ def cluster_raw_generate(
# warnmsg = f"Text generated by {js[0]['server_name']}"
# emit('from_server', {'cmd': 'warnmsg', 'data': warnmsg}, broadcast=True)
return np.array([tokenizer.encode(cgen["text"]) for cgen in js])
return np.array([tokenizer.encode(cgen["text"]) for cgen in generations])
def colab_raw_generate(
prompt_tokens: List[int],
@@ -9846,8 +9821,11 @@ def text2img_horde(prompt: str) -> Optional[Image.Image]:
"height": 512
}
}
cluster_headers = {"apikey": koboldai_vars.sh_apikey or "0000000000"}
client_agent = "KoboldAI:2.0.0:(discord)ebolam#1007"
cluster_headers = {
'apikey': koboldai_vars.apikey or "0000000000",
"Client-Agent": client_agent
}
id_req = requests.post("https://stablehorde.net/api/v2/generate/async", json=final_submit_dict, headers=cluster_headers)
if not id_req.ok:
@@ -9895,11 +9873,14 @@ def text2img_horde(prompt: str) -> Optional[Image.Image]:
if len(results["generations"]) > 1:
logger.warning(f"Got too many generations, discarding extras. Got {len(results['generations'])}, expected 1.")
b64img = results["generations"][0]["img"]
base64_bytes = b64img.encode("utf-8")
img_bytes = base64.b64decode(base64_bytes)
img = Image.open(BytesIO(img_bytes))
return img
imgurl = results["generations"][0]["img"]
try:
img_data = requests.get(imgurl, timeout=3).content
img = Image.open(BytesIO(img_data))
return img
except Exception as err:
logger.error(f"Error retrieving image: {err}")
raise HordeException("Image fetching failed. See console for more details.")
@logger.catch
def text2img_api(prompt, art_guide="") -> Image.Image: