Divided by Zer0 2023-02-23 18:27:11 +01:00
parent c27faf56e6
commit 2de9672b95
1 changed file with 80 additions and 10 deletions

@@ -1671,7 +1671,7 @@ def get_cluster_models(msg):
     # Get list of models from public cluster
     logger.init("KAI Horde Models", status="Retrieving")
     try:
-        req = requests.get("{}/api/v1/models".format(url))
+        req = requests.get(f"{url}/api/v2/status/models?type=text")
     except requests.exceptions.ConnectionError:
         logger.init_err("KAI Horde Models", status="Failed")
         logger.error("Provided KoboldAI Horde URL unreachable")
@@ -1687,10 +1687,11 @@ def get_cluster_models(msg):
     engines = req.json()
     logger.debug(engines)
     try:
-        engines = [[en, en] for en in engines]
+        engines = [[en["name"], en["name"]] for en in engines]
     except:
         logger.error(engines)
         raise
+    logger.debug(engines)
 
     online_model = ""
     changed=False
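
The two hunks above move the model lookup from the v1 /api/v1/models route to the v2 /api/v2/status/models?type=text route, which returns a list of model objects rather than plain name strings, hence the en["name"] access. A minimal standalone sketch of the same lookup, assuming a placeholder base URL:

import requests

url = "https://horde.koboldai.net"  # placeholder base URL for illustration

try:
    req = requests.get(f"{url}/api/v2/status/models?type=text")
    req.raise_for_status()
except requests.exceptions.RequestException as err:
    raise SystemExit(f"KoboldAI Horde unreachable: {err}")

# Each entry carries at least a "name" key; keep the [display, value]
# pairs the model picker expects.
engines = [[en["name"], en["name"]] for en in req.json()]
print(engines[:5])
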
@@ -5269,15 +5270,21 @@ def sendtocluster(txt, min, max):
     cluster_metadata = {
         'prompt': txt,
         'params': reqdata,
-        'api_key': vars.apikey,
         'models': vars.cluster_requested_models,
-    }
+        'trusted_workers': False,
+    }
+    client_agent = "KoboldAI:1.19.3:(discord)Henky!!#2205"
+    cluster_headers = {
+        'apikey': koboldai_vars.apikey,
+        "Client-Agent": client_agent
+    }
     logger.debug(f"Horde Payload: {cluster_metadata}")
     try:
         # Create request
         req = requests.post(
-            vars.colaburl[:-8] + "/api/v1/generate/sync",
+            vars.colaburl[:-8] + "/api/v2/generate/text/async",
             json=cluster_metadata,
+            headers=cluster_headers,
         )
     except requests.exceptions.ConnectionError:
         errmsg = f"Horde unavailable. Please try again later"
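
This hunk moves the API key out of the JSON payload and into an 'apikey' header, adds a Client-Agent header, and targets the asynchronous v2 text endpoint. A self-contained sketch of the same submission, with placeholder base URL, prompt, params, and key standing in for the real vars.apikey, reqdata, and vars.cluster_requested_models:

import requests

base_url = "https://horde.koboldai.net"      # placeholder base URL
client_agent = "KoboldAI:1.19.3:(discord)Henky!!#2205"

cluster_metadata = {
    "prompt": "Once upon a time",            # placeholder prompt
    "params": {"max_length": 80},            # placeholder for reqdata
    "models": [],                            # placeholder for vars.cluster_requested_models
    "trusted_workers": False,
}
cluster_headers = {
    "apikey": "0000000000",                  # anonymous key, illustration only
    "Client-Agent": client_agent,
}

req = requests.post(
    f"{base_url}/api/v2/generate/text/async",
    json=cluster_metadata,
    headers=cluster_headers,
)
req.raise_for_status()
request_id = req.json()["id"]                # consumed by the status polling below
print(f"Horde Request ID: {request_id}")
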
@@ -5305,13 +5312,76 @@ def sendtocluster(txt, min, max):
         emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
         set_aibusy(0)
         return
-    gen_servers = [(cgen['server_name'],cgen['server_id']) for cgen in js]
-    logger.info(f"Generations by: {gen_servers}")
+    request_id = js["id"]
+    logger.debug("Horde Request ID: {}".format(request_id))
+    cluster_agent_headers = {
+        "Client-Agent": client_agent
+    }
+    finished = False
+    while not finished:
+        try:
+            req = requests.get(koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/status/" + request_id, headers=cluster_agent_headers)
+        except requests.exceptions.ConnectionError:
+            errmsg = f"Horde unavailable. Please try again later"
+            logger.error(errmsg)
+            emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
+            set_aibusy(0)
+            return
+        if not req.ok:
+            errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
+            logger.error(req.text)
+            emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
+            set_aibusy(0)
+            return
+        try:
+            req_status = req.json()
+        except requests.exceptions.JSONDecodeError:
+            errmsg = f"Unexpected message received from the KoboldAI Horde: '{req.text}'"
+            logger.error(errmsg)
+            emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
+            set_aibusy(0)
+            return
+        if "done" not in req_status:
+            errmsg = f"Unexpected response received from the KoboldAI Horde: '{js}'"
+            logger.error(errmsg)
+            emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
+            set_aibusy(0)
+            return
+        finished = req_status["done"]
+        if not finished:
+            logger.debug(req_status)
+            time.sleep(1)
+    logger.debug("Last Horde Status Message: {}".format(js))
+    if req_status["faulted"]:
+        errmsg = "Horde Text generation faulted! Please try again"
+        logger.error(errmsg)
+        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
+        set_aibusy(0)
+        return
+    generations = req_status['generations']
+    gen_workers = [(cgen['worker_name'],cgen['worker_id']) for cgen in generations]
+    logger.info(f"Generations by: {gen_workers}")
     # Just in case we want to announce it to the user
-    if len(js) == 1:
-        warnmsg = f"Text generated by {js[0]['server_name']}"
+    if len(generations) == 1:
+        warnmsg = f"Text generated by {[w[0] for w in gen_workers]}"
         emit('from_server', {'cmd': 'warnmsg', 'data': warnmsg}, broadcast=True)
-    genout = [cgen['text'] for cgen in js]
+    genout = [cgen['text'] for cgen in generations]
     for i in range(vars.numseqs):
         vars.lua_koboldbridge.outputs[i+1] = genout[i]
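
The bulk of the new code is the polling loop: hit /api/v2/generate/text/status/<id> once a second until the job reports done, bail out if it faulted, then collect the generations and worker names. A condensed sketch of that flow, with the per-error emit/set_aibusy handling reduced to exceptions for brevity:

import time
import requests

def poll_horde_text(base_url, request_id, client_agent):
    """Poll the v2 status endpoint until the request is done, then return
    the generated texts and the (worker_name, worker_id) pairs."""
    headers = {"Client-Agent": client_agent}
    while True:
        req = requests.get(
            f"{base_url}/api/v2/generate/text/status/{request_id}",
            headers=headers,
        )
        req.raise_for_status()
        status = req.json()
        if status.get("done"):
            break
        time.sleep(1)
    if status.get("faulted"):
        raise RuntimeError("Horde Text generation faulted! Please try again")
    generations = status["generations"]
    workers = [(gen["worker_name"], gen["worker_id"]) for gen in generations]
    return [gen["text"] for gen in generations], workers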