mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge pull request #294 from db0/kaimerge
Changes to work with the merged hordes
This commit is contained in:
115
aiserver.py
115
aiserver.py
@@ -1876,49 +1876,13 @@ def get_cluster_models(msg):
|
||||
# Get list of models from public cluster
|
||||
print("{0}Retrieving engine list...{1}".format(colors.PURPLE, colors.END), end="")
|
||||
try:
|
||||
req = requests.get("{}/api/v1/models".format(url))
|
||||
req = requests.get(f"{url}/api/v2/status/models?type=text")
|
||||
except:
|
||||
logger.init_err("KAI Horde Models", status="Failed")
|
||||
logger.error("Provided KoboldAI Horde URL unreachable")
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': "Provided KoboldAI Horde URL unreachable"})
|
||||
return
|
||||
if(req.status_code == 200):
|
||||
engines = req.json()
|
||||
print(engines)
|
||||
try:
|
||||
engines = [[en, en] for en in engines]
|
||||
except:
|
||||
print(engines)
|
||||
raise
|
||||
print(engines)
|
||||
|
||||
online_model = ""
|
||||
changed=False
|
||||
|
||||
#Save the key
|
||||
if not path.exists("settings"):
|
||||
# If the client settings file doesn't exist, create it
|
||||
# Write API key to file
|
||||
os.makedirs('settings', exist_ok=True)
|
||||
if path.exists(get_config_filename(model)):
|
||||
with open(get_config_filename(model), "r") as file:
|
||||
js = json.load(file)
|
||||
if 'online_model' in js:
|
||||
online_model = js['online_model']
|
||||
if "apikey" in js:
|
||||
if js['apikey'] != koboldai_vars.oaiapikey:
|
||||
changed=True
|
||||
else:
|
||||
changed=True
|
||||
if changed:
|
||||
js={}
|
||||
with open(get_config_filename(model), "w") as file:
|
||||
js["apikey"] = koboldai_vars.oaiapikey
|
||||
file.write(json.dumps(js, indent=3))
|
||||
|
||||
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True, room="UI_1")
|
||||
emit('oai_engines', {'data': engines, 'online_model': online_model}, broadcast=False, room="UI_2")
|
||||
else:
|
||||
if not req.ok:
|
||||
# Something went wrong, print the message and quit since we can't initialize an engine
|
||||
logger.init_err("KAI Horde Models", status="Failed")
|
||||
logger.error(req.json())
|
||||
@@ -1928,10 +1892,11 @@ def get_cluster_models(msg):
|
||||
engines = req.json()
|
||||
logger.debug(engines)
|
||||
try:
|
||||
engines = [[en, en] for en in engines]
|
||||
engines = [[en["name"], en["name"]] for en in engines]
|
||||
except:
|
||||
logger.error(engines)
|
||||
raise
|
||||
logger.debug(engines)
|
||||
|
||||
online_model = ""
|
||||
changed=False
|
||||
@@ -1956,10 +1921,12 @@ def get_cluster_models(msg):
|
||||
with open(get_config_filename(model), "w") as file:
|
||||
js["apikey"] = koboldai_vars.oaiapikey
|
||||
js["url"] = url
|
||||
file.write(json.dumps(js, indent=3))
|
||||
|
||||
file.write(json.dumps(js, indent=3))
|
||||
|
||||
logger.init_ok("KAI Horde Models", status="OK")
|
||||
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True)
|
||||
|
||||
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True, room="UI_1")
|
||||
emit('oai_engines', {'data': engines, 'online_model': online_model}, broadcast=False, room="UI_2")
|
||||
|
||||
|
||||
# Function to patch transformers to use our soft prompt
|
||||
@@ -5965,12 +5932,15 @@ def cluster_raw_generate(
|
||||
'trusted_workers': False,
|
||||
}
|
||||
|
||||
cluster_headers = {'apikey': koboldai_vars.apikey}
|
||||
|
||||
client_agent = "KoboldAI:2.0.0:(discord)ebolam#1007"
|
||||
cluster_headers = {
|
||||
'apikey': koboldai_vars.apikey,
|
||||
"Client-Agent": client_agent
|
||||
}
|
||||
try:
|
||||
# Create request
|
||||
req = requests.post(
|
||||
koboldai_vars.colaburl[:-8] + "/api/v2/generate/async",
|
||||
koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/async",
|
||||
json=cluster_metadata,
|
||||
headers=cluster_headers
|
||||
)
|
||||
@@ -6003,9 +5973,13 @@ def cluster_raw_generate(
|
||||
# We've sent the request and got the ID back, now we need to watch it to see when it finishes
|
||||
finished = False
|
||||
|
||||
cluster_agent_headers = {
|
||||
"Client-Agent": client_agent
|
||||
}
|
||||
|
||||
while not finished:
|
||||
try:
|
||||
req = requests.get(koboldai_vars.colaburl[:-8] + "/api/v1/generate/check/" + request_id)
|
||||
try:
|
||||
req = requests.get(koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/status/" + request_id, headers=cluster_agent_headers)
|
||||
except requests.exceptions.ConnectionError:
|
||||
errmsg = f"Horde unavailable. Please try again later"
|
||||
logger.error(errmsg)
|
||||
@@ -6017,32 +5991,33 @@ def cluster_raw_generate(
|
||||
raise HordeException(errmsg)
|
||||
|
||||
try:
|
||||
js = req.json()
|
||||
req_status = req.json()
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
errmsg = f"Unexpected message received from the KoboldAI Horde: '{req.text}'"
|
||||
logger.error(errmsg)
|
||||
raise HordeException(errmsg)
|
||||
|
||||
if "done" not in js:
|
||||
if "done" not in req_status:
|
||||
errmsg = f"Unexpected response received from the KoboldAI Horde: '{js}'"
|
||||
logger.error(errmsg )
|
||||
logger.error(errmsg)
|
||||
raise HordeException(errmsg)
|
||||
|
||||
finished = js["done"]
|
||||
koboldai_vars.horde_wait_time = js["wait_time"]
|
||||
koboldai_vars.horde_queue_position = js["queue_position"]
|
||||
koboldai_vars.horde_queue_size = js["waiting"]
|
||||
finished = req_status["done"]
|
||||
koboldai_vars.horde_wait_time = req_status["wait_time"]
|
||||
koboldai_vars.horde_queue_position = req_status["queue_position"]
|
||||
koboldai_vars.horde_queue_size = req_status["waiting"]
|
||||
|
||||
if not finished:
|
||||
logger.debug(js)
|
||||
logger.debug(req_status)
|
||||
time.sleep(1)
|
||||
|
||||
logger.debug("Last Horde Status Message: {}".format(js))
|
||||
js = requests.get(koboldai_vars.colaburl[:-8] + "/api/v1/generate/prompt/" + request_id).json()['generations']
|
||||
logger.debug("Horde Result: {}".format(js))
|
||||
if req_status["faulted"]:
|
||||
raise HordeException("Horde Text generation faulted! Please try again")
|
||||
|
||||
gen_servers = [(cgen['server_name'],cgen['server_id']) for cgen in js]
|
||||
logger.info(f"Generations by: {gen_servers}")
|
||||
generations = req_status['generations']
|
||||
gen_workers = [(cgen['worker_name'],cgen['worker_id']) for cgen in generations]
|
||||
logger.info(f"Generations by: {gen_workers}")
|
||||
|
||||
# TODO: Fix this, using tpool so it's a context error
|
||||
# Just in case we want to announce it to the user
|
||||
@@ -6050,7 +6025,7 @@ def cluster_raw_generate(
|
||||
# warnmsg = f"Text generated by {js[0]['server_name']}"
|
||||
# emit('from_server', {'cmd': 'warnmsg', 'data': warnmsg}, broadcast=True)
|
||||
|
||||
return np.array([tokenizer.encode(cgen["text"]) for cgen in js])
|
||||
return np.array([tokenizer.encode(cgen["text"]) for cgen in generations])
|
||||
|
||||
def colab_raw_generate(
|
||||
prompt_tokens: List[int],
|
||||
@@ -9846,8 +9821,11 @@ def text2img_horde(prompt: str) -> Optional[Image.Image]:
|
||||
"height": 512
|
||||
}
|
||||
}
|
||||
|
||||
cluster_headers = {"apikey": koboldai_vars.sh_apikey or "0000000000"}
|
||||
client_agent = "KoboldAI:2.0.0:(discord)ebolam#1007"
|
||||
cluster_headers = {
|
||||
'apikey': koboldai_vars.apikey or "0000000000",
|
||||
"Client-Agent": client_agent
|
||||
}
|
||||
id_req = requests.post("https://stablehorde.net/api/v2/generate/async", json=final_submit_dict, headers=cluster_headers)
|
||||
|
||||
if not id_req.ok:
|
||||
@@ -9895,11 +9873,14 @@ def text2img_horde(prompt: str) -> Optional[Image.Image]:
|
||||
if len(results["generations"]) > 1:
|
||||
logger.warning(f"Got too many generations, discarding extras. Got {len(results['generations'])}, expected 1.")
|
||||
|
||||
b64img = results["generations"][0]["img"]
|
||||
base64_bytes = b64img.encode("utf-8")
|
||||
img_bytes = base64.b64decode(base64_bytes)
|
||||
img = Image.open(BytesIO(img_bytes))
|
||||
return img
|
||||
imgurl = results["generations"][0]["img"]
|
||||
try:
|
||||
img_data = requests.get(imgurl, timeout=3).content
|
||||
img = Image.open(BytesIO(img_data))
|
||||
return img
|
||||
except Exception as err:
|
||||
logger.error(f"Error retrieving image: {err}")
|
||||
raise HordeException("Image fetching failed. See console for more details.")
|
||||
|
||||
@logger.catch
|
||||
def text2img_api(prompt, art_guide="") -> Image.Image:
|
||||
|
Reference in New Issue
Block a user