Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Model: And another refactor
modeling/inference_models/horde.py (new file, +167 lines)
@@ -0,0 +1,167 @@
from __future__ import annotations

import time
import torch
import requests
import numpy as np
from typing import List, Union

import utils
from logger import logger

from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    InferenceModel,
)


class HordeException(Exception):
    """To be used for errors on the server side of the Horde."""


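# Note: this backend performs no local inference. _raw_generate below delegates
# generation to remote Horde workers over HTTP, so _load only needs a tokenizer
# for encoding prompts and decoding results.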
class HordeInferenceModel(InferenceModel):
    def _load(self, save_model: bool, initial_load: bool) -> None:
        # Use the tokenizer of the first requested cluster model, falling back
        # to GPT-2's tokenizer when no model has been requested.
        self.tokenizer = self._get_tokenizer(
            utils.koboldai_vars.cluster_requested_models[0]
            if len(utils.koboldai_vars.cluster_requested_models) > 0
            else "gpt2",
        )

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
    ) -> GenerationResult:
        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

        # Store context in memory to use it for comparison with generated content
        utils.koboldai_vars.lastctx = decoded_prompt

        # Build request JSON data
        reqdata = {
            "max_length": max_new,
            "max_context_length": utils.koboldai_vars.max_length,
            "rep_pen": gen_settings.rep_pen,
            "rep_pen_slope": gen_settings.rep_pen_slope,
            "rep_pen_range": gen_settings.rep_pen_range,
            "temperature": gen_settings.temp,
            "top_p": gen_settings.top_p,
            "top_k": int(gen_settings.top_k),
            "top_a": gen_settings.top_a,
            "tfs": gen_settings.tfs,
            "typical": gen_settings.typical,
            "n": batch_count,
        }

        cluster_metadata = {
            "prompt": decoded_prompt,
            "params": reqdata,
            "models": [x for x in utils.koboldai_vars.cluster_requested_models if x],
            "trusted_workers": False,
        }

        client_agent = "KoboldAI:2.0.0:koboldai.org"
        cluster_headers = {
            "apikey": utils.koboldai_vars.horde_api_key,
            "Client-Agent": client_agent,
        }

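        # For illustration only (the model name and parameter values here are
        # hypothetical), the JSON body submitted to the Horde is shaped like:
        #
        #   {
        #       "prompt": "...",
        #       "params": {"max_length": 80, "temperature": 0.5, "n": 1, ...},
        #       "models": ["example/some-model"],
        #       "trusted_workers": false,
        #   }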
        try:
            # Submit the asynchronous generation request
            req = requests.post(
                utils.koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/async",
                json=cluster_metadata,
                headers=cluster_headers,
            )
        except requests.exceptions.ConnectionError:
            errmsg = "Horde unavailable. Please try again later."
            logger.error(errmsg)
            raise HordeException(errmsg)

        if req.status_code == 503:
            errmsg = "KoboldAI API Error: No available KoboldAI servers found in Horde to fulfil this request using the selected models or other properties."
            logger.error(errmsg)
            raise HordeException(errmsg)
        elif not req.ok:
            errmsg = "KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
            logger.error(errmsg)
            logger.error(f"HTTP {req.status_code}")
            logger.error(req.text)
            raise HordeException(errmsg)

        try:
            req_status = req.json()
        except requests.exceptions.JSONDecodeError:
            errmsg = f"Unexpected message received from the Horde: '{req.text}'"
            logger.error(errmsg)
            raise HordeException(errmsg)

        request_id = req_status["id"]
        logger.debug("Horde Request ID: {}".format(request_id))

        # We've sent the request and got the ID back; now we poll the status
        # endpoint until generation finishes.
        finished = False

        cluster_agent_headers = {"Client-Agent": client_agent}

        while not finished:
            try:
                req = requests.get(
                    f"{utils.koboldai_vars.colaburl[:-8]}/api/v2/generate/text/status/{request_id}",
                    headers=cluster_agent_headers,
                )
            except requests.exceptions.ConnectionError:
                errmsg = "Horde unavailable. Please try again later."
                logger.error(errmsg)
                raise HordeException(errmsg)

            if not req.ok:
                errmsg = "KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
                logger.error(req.text)
                raise HordeException(errmsg)

            try:
                req_status = req.json()
            except requests.exceptions.JSONDecodeError:
                errmsg = (
                    f"Unexpected message received from the KoboldAI Horde: '{req.text}'"
                )
                logger.error(errmsg)
                raise HordeException(errmsg)

            if "done" not in req_status:
                errmsg = f"Unexpected response received from the KoboldAI Horde: '{req_status}'"
                logger.error(errmsg)
                raise HordeException(errmsg)

            finished = req_status["done"]
            utils.koboldai_vars.horde_wait_time = req_status["wait_time"]
            utils.koboldai_vars.horde_queue_position = req_status["queue_position"]
            utils.koboldai_vars.horde_queue_size = req_status["waiting"]

            if not finished:
                logger.debug(req_status)
                time.sleep(1)
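
        # For illustration only (field values are hypothetical), a status
        # response consumed by the loop above looks roughly like:
        #
        #   {
        #       "done": true,
        #       "faulted": false,
        #       "wait_time": 0,
        #       "queue_position": 0,
        #       "waiting": 0,
        #       "generations": [
        #           {"text": "...", "worker_name": "...", "worker_id": "..."}
        #       ],
        #   }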

        logger.debug("Last Horde Status Message: {}".format(req_status))

        if req_status["faulted"]:
            raise HordeException("Horde Text generation faulted! Please try again.")

        generations = req_status["generations"]
        gen_servers = [(cgen["worker_name"], cgen["worker_id"]) for cgen in generations]
        logger.info(f"Generations by: {gen_servers}")

        return GenerationResult(
            model=self,
            out_batches=np.array(
                [self.tokenizer.encode(cgen["text"]) for cgen in generations]
            ),
            prompt=prompt_tokens,
            is_whole_generation=True,
            single_line=single_line,
        )
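
For reference, a minimal standalone sketch of the same submit-and-poll flow, independent of KoboldAI's internals. The base URL and parameter values are placeholder assumptions; the endpoints and response fields mirror the ones used in the file above:

import time

import requests

BASE_URL = "https://horde.koboldai.net"  # assumed Horde endpoint; adjust as needed
CLIENT_AGENT = "KoboldAI:2.0.0:koboldai.org"

def horde_generate(prompt, api_key, max_length=80):
    # Submit an asynchronous text generation request and keep its ID.
    submit = requests.post(
        f"{BASE_URL}/api/v2/generate/text/async",
        json={
            "prompt": prompt,
            "params": {"max_length": max_length, "n": 1},
            "trusted_workers": False,
        },
        headers={"apikey": api_key, "Client-Agent": CLIENT_AGENT},
    )
    submit.raise_for_status()
    request_id = submit.json()["id"]

    # Poll the status endpoint until the request reports completion.
    while True:
        status = requests.get(
            f"{BASE_URL}/api/v2/generate/text/status/{request_id}",
            headers={"Client-Agent": CLIENT_AGENT},
        ).json()
        if status.get("done"):
            return [gen["text"] for gen in status["generations"]]
        time.sleep(1)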