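"""Model backend for the OpenAI / GooseAI HTTP completions API.

Instead of loading model weights locally, this backend forwards prompts to
a remote engine and re-encodes the returned text as tokens.
"""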
import json
import os
from typing import List, Optional, Union

import numpy as np
import requests
import torch

# emit() is used in get_oai_models() to push errors to the web client; in
# KoboldAI it comes from flask-socketio (assumed here, since the original
# file referenced it without an import).
from flask_socketio import emit

import utils
from logger import logger
from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    InferenceModel,
)


class OpenAIAPIError(Exception):
    """Raised when the remote API answers a request with an error body."""

    def __init__(self, error_type: str, error_message) -> None:
        super().__init__(f"{error_type}: {error_message}")


class model_backend(InferenceModel):
    """InferenceModel for interfacing with OpenAI's generation API."""

    def __init__(self):
        super().__init__()
        self.key = ""
        # Engines listing endpoint; completion requests go to
        # "<url>/<engine_id>/completions" (see _raw_generate).
        self.url = "https://api.goose.ai/v1/engines"

    def is_valid(self, model_name, model_path, menu_path):
        return model_name in ("OAI", "GooseAI")

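    # Saved settings live at "settings/<source>.model_backend.settings",
    # e.g. "settings/GooseAI.model_backend.settings", and hold {"key": "..."}
    # (see _save_settings below).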
    def get_requested_parameters(self, model_name, model_path, menu_path):
        # Record the source first: the saved-settings filename below is
        # derived from it, and it is unset on a fresh instance.
        self.source = model_name
        if os.path.exists("settings/{}.model_backend.settings".format(self.source)) and "colaburl" not in vars(self):
            with open("settings/{}.model_backend.settings".format(self.source), "r") as f:
                self.key = json.load(f)["key"]
        requested_parameters = []
        requested_parameters.extend([
            {
                "uitype": "text",
                "unit": "text",
                "label": "Key",
                "id": "key",
                "default": self.key,
                "check": {"value": "", "check": "!="},
                "tooltip": "User Key to use when connecting to OpenAI/GooseAI.",
                "menu_path": "",
                "refresh_model_inputs": True,
                "extra_classes": "",
            },
            {
                "uitype": "dropdown",
                "unit": "text",
                "label": "Model",
                "id": "model",
                "default": "",
                "check": {"value": "", "check": "!="},
                "tooltip": "Which model to use when running OpenAI/GooseAI.",
                "menu_path": "",
                "refresh_model_inputs": False,
                "extra_classes": "",
                "children": self.get_oai_models(),
            },
        ])
        return requested_parameters

    def set_input_parameters(self, parameters):
        self.key = parameters["key"].strip()
        self.model = parameters["model"]

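    # A successful GET on the engines endpoint returns JSON along the lines
    # of (illustrative, trimmed to the fields used below):
    #   {"data": [{"id": "gpt-neo-20b", "ready": true}, ...]}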
    def get_oai_models(self):
        if self.key == "":
            return []

        # Get list of models from OAI
        logger.init("OAI Engines", status="Retrieving")
        req = requests.get(
            self.url,
            headers={"Authorization": "Bearer " + self.key},
        )
        if req.status_code == 200:
            r = req.json()
            engines = r["data"]
            try:
                engines = [
                    {
                        "value": en["id"],
                        "text": "{} ({})".format(en["id"], "Ready" if en["ready"] else "Not Ready"),
                    }
                    for en in engines
                ]
            except Exception:
                # Dump the raw payload so a malformed entry can be diagnosed.
                logger.error(engines)
                raise

            logger.init_ok("OAI Engines", status="OK")
            return engines
        else:
            # Something went wrong; log the message and return an empty list
            # since we can't initialize an engine.
            logger.init_err("OAI Engines", status="Failed")
            logger.error(req.json())
            emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
            return []

    def _load(self, save_model: bool, initial_load: bool) -> None:
        # No weights to load; only a tokenizer is needed to round-trip text.
        self.tokenizer = self._get_tokenizer("gpt2")

    def _save_settings(self):
        # Persist only the API key; get_requested_parameters() reads it back.
        with open("settings/{}.model_backend.settings".format(self.source), "w") as f:
            json.dump({"key": self.key}, f, indent="")

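    # _raw_generate() POSTs a completion request and expects a response of
    # the form (illustrative): {"choices": [{"text": "..."}, ...]}, or
    # {"error": {"type": ..., "message": ...}} on failure.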
    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        seed: Optional[int] = None,
        **kwargs,
    ) -> GenerationResult:
        if seed is not None:
            logger.warning(
                "Seed is unsupported on the OpenAIAPIInferenceModel. Seed will be ignored."
            )

        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

        # Store context in memory to use it for comparison with generated content
        utils.koboldai_vars.lastctx = decoded_prompt

        # Build request JSON data.
        # GooseAI is a subtype of OAI, so to detect it we check the configname
        # as a workaround, as koboldai_vars.model will always be OAI.
        if "GooseAI" in utils.koboldai_vars.configname:
            # GooseAI accepts the full set of KoboldAI samplers.
            reqdata = {
                "prompt": decoded_prompt,
                "max_tokens": max_new,
                "temperature": gen_settings.temp,
                "top_a": gen_settings.top_a,
                "top_p": gen_settings.top_p,
                "top_k": gen_settings.top_k,
                "tfs": gen_settings.tfs,
                "typical_p": gen_settings.typical,
                "repetition_penalty": gen_settings.rep_pen,
                "repetition_penalty_slope": gen_settings.rep_pen_slope,
                "repetition_penalty_range": gen_settings.rep_pen_range,
                "n": batch_count,
                # TODO: Implement streaming
                "stream": False,
            }
        else:
            # Plain OpenAI supports a narrower sampler set; rep_pen is sent
            # as frequency_penalty.
            reqdata = {
                "prompt": decoded_prompt,
                "max_tokens": max_new,
                "temperature": gen_settings.temp,
                "top_p": gen_settings.top_p,
                "frequency_penalty": gen_settings.rep_pen,
                "n": batch_count,
                "stream": False,
            }

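        # Illustrative GooseAI-branch JSON payload (values depend on the UI
        # settings):
        #   {"prompt": "Once upon a time", "max_tokens": 80,
        #    "temperature": 0.5, "top_p": 0.9, ..., "n": 1, "stream": false}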
        # The completions endpoint hangs off the selected engine; posting to
        # the bare engines URL (self.url) would not reach it, so the full
        # path is built here.
        req = requests.post(
            "{}/{}/completions".format(self.url, self.model),
            json=reqdata,
            headers={
                "Authorization": "Bearer " + self.key,
                "Content-Type": "application/json",
            },
        )

        j = req.json()

        if not req.ok:
            # Relay the API's error to the caller; fall back to "Unknown"
            # when the response body carries no error object.
            if "error" in j:
                error_type = j["error"]["type"]
                error_message = j["error"]["message"]
            else:
                error_type = "Unknown"
                error_message = "Unknown"
            raise OpenAIAPIError(error_type, error_message)

        outputs = [out["text"] for out in j["choices"]]
        return GenerationResult(
            model=self,
            out_batches=np.array([self.tokenizer.encode(x) for x in outputs]),
            prompt=prompt_tokens,
            is_whole_generation=True,
            single_line=single_line,
        )