mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
129 lines
4.5 KiB
Python
129 lines
4.5 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
import json
|
|
import torch
|
|
import requests
|
|
import numpy as np
|
|
from typing import List, Optional, Union
|
|
import os
|
|
|
|
import utils
|
|
from logger import logger
|
|
|
|
from modeling.inference_model import (
|
|
GenerationResult,
|
|
GenerationSettings,
|
|
InferenceModel,
|
|
ModelCapabilities,
|
|
)
|
|
|
|
model_backend_name = "KoboldAI API"
|
|
|
|
class APIException(Exception):
|
|
"""To be used for errors when using the Kobold API as an interface."""
|
|
|
|
|
|
class model_backend(InferenceModel):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.base_url = ""
|
|
|
|
def is_valid(self, model_name, model_path, menu_path):
|
|
return model_name == "API"
|
|
|
|
def get_requested_parameters(self, model_name, model_path, menu_path):
|
|
if os.path.exists("settings/api.model_backend.settings") and 'base_url' not in vars(self):
|
|
with open("settings/api.model_backend.settings", "r") as f:
|
|
self.base_url = json.load(f)['base_url']
|
|
requested_parameters = []
|
|
requested_parameters.append({
|
|
"uitype": "text",
|
|
"unit": "text",
|
|
"label": "URL",
|
|
"id": "base_url",
|
|
"default": self.base_url,
|
|
"check": {"value": "", 'check': "!="},
|
|
"tooltip": "The URL of the KoboldAI API to connect to.",
|
|
"menu_path": "",
|
|
"extra_classes": "",
|
|
"refresh_model_inputs": False
|
|
})
|
|
return requested_parameters
|
|
|
|
def set_input_parameters(self, parameters):
|
|
self.base_url = parameters['base_url'].rstrip("/")
|
|
|
|
def _load(self, save_model: bool, initial_load: bool) -> None:
|
|
tokenizer_id = requests.get(f"{self.base_url}/api/v1/model").json()["result"]
|
|
|
|
self.tokenizer = self._get_tokenizer(tokenizer_id)
|
|
|
|
# Do not allow API to be served over the API
|
|
self.capabilties = ModelCapabilities(api_host=False)
|
|
|
|
def _save_settings(self):
|
|
with open("settings/api.model_backend.settings", "w") as f:
|
|
json.dump({"base_url": self.base_url}, f, indent="")
|
|
|
|
def _raw_generate(
|
|
self,
|
|
prompt_tokens: Union[List[int], torch.Tensor],
|
|
max_new: int,
|
|
gen_settings: GenerationSettings,
|
|
single_line: bool = False,
|
|
batch_count: int = 1,
|
|
seed: Optional[int] = None,
|
|
**kwargs,
|
|
):
|
|
|
|
if seed is not None:
|
|
logger.warning(
|
|
"Seed is unsupported on the APIInferenceModel. Seed will be ignored."
|
|
)
|
|
|
|
decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
|
|
|
|
# Store context in memory to use it for comparison with generated content
|
|
utils.koboldai_vars.lastctx = decoded_prompt
|
|
|
|
# Build request JSON data
|
|
reqdata = {
|
|
"prompt": decoded_prompt,
|
|
"max_length": max_new,
|
|
"max_context_length": utils.koboldai_vars.max_length,
|
|
"rep_pen": gen_settings.rep_pen,
|
|
"rep_pen_slope": gen_settings.rep_pen_slope,
|
|
"rep_pen_range": gen_settings.rep_pen_range,
|
|
"temperature": gen_settings.temp,
|
|
"top_p": gen_settings.top_p,
|
|
"top_k": gen_settings.top_k,
|
|
"top_a": gen_settings.top_a,
|
|
"tfs": gen_settings.tfs,
|
|
"typical": gen_settings.typical,
|
|
"n": batch_count,
|
|
}
|
|
|
|
# Create request
|
|
while True:
|
|
req = requests.post(f"{self.base_url}/api/v1/generate", json=reqdata)
|
|
|
|
if req.status_code == 503:
|
|
# Server is currently generating something else so poll until it's our turn
|
|
time.sleep(1)
|
|
continue
|
|
|
|
js = req.json()
|
|
if req.status_code != 200:
|
|
logger.error(json.dumps(js, indent=4))
|
|
raise APIException(f"Bad API status code {req.status_code}")
|
|
|
|
genout = [obj["text"] for obj in js["results"]]
|
|
return GenerationResult(
|
|
model=self,
|
|
out_batches=np.array([self.tokenizer.encode(x) for x in genout]),
|
|
prompt=prompt_tokens,
|
|
is_whole_generation=True,
|
|
single_line=single_line,
|
|
)
|