Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Model: And another refactor
98  modeling/inference_models/openai.py  Normal file
@@ -0,0 +1,98 @@
import torch
import requests
import numpy as np
from typing import List, Union

import utils
from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    InferenceModel,
)


class OpenAIAPIError(Exception):
    def __init__(self, error_type: str, error_message) -> None:
        super().__init__(f"{error_type}: {error_message}")


class OpenAIAPIInferenceModel(InferenceModel):
    """InferenceModel for interfacing with OpenAI's generation API."""
    def _load(self, save_model: bool, initial_load: bool) -> None:
        self.tokenizer = self._get_tokenizer("gpt2")

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
    ) -> GenerationResult:
        # Taken mainly from oairequest()

        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

        # Store context in memory to use it for comparison with generated content
        utils.koboldai_vars.lastctx = decoded_prompt

        # Build request JSON data
        # GooseAI is a subtype of OAI. So to check if it's this type, we check the configname as a workaround,
        # as koboldai_vars.model will always be OAI.
        if "GooseAI" in utils.koboldai_vars.configname:
            reqdata = {
                "prompt": decoded_prompt,
                "max_tokens": max_new,
                "temperature": gen_settings.temp,
                "top_a": gen_settings.top_a,
                "top_p": gen_settings.top_p,
                "top_k": gen_settings.top_k,
                "tfs": gen_settings.tfs,
                "typical_p": gen_settings.typical,
                "repetition_penalty": gen_settings.rep_pen,
                "repetition_penalty_slope": gen_settings.rep_pen_slope,
                "repetition_penalty_range": gen_settings.rep_pen_range,
                "n": batch_count,
                # TODO: Implement streaming
                "stream": False,
            }
        else:
            reqdata = {
                "prompt": decoded_prompt,
                "max_tokens": max_new,
                "temperature": gen_settings.temp,
                "top_p": gen_settings.top_p,
                "frequency_penalty": gen_settings.rep_pen,
                "n": batch_count,
                "stream": False,
            }

        req = requests.post(
            utils.koboldai_vars.oaiurl,
            json=reqdata,
            headers={
                "Authorization": "Bearer " + utils.koboldai_vars.oaiapikey,
                "Content-Type": "application/json",
            },
        )

        j = req.json()

        if not req.ok:
            # Send error message to web client
            if "error" in j:
                error_type = j["error"]["type"]
                error_message = j["error"]["message"]
            else:
                error_type = "Unknown"
                error_message = "Unknown"
            raise OpenAIAPIError(error_type, error_message)

        outputs = [out["text"] for out in j["choices"]]
        return GenerationResult(
            model=self,
            out_batches=np.array([self.tokenizer.encode(x) for x in outputs]),
            prompt=prompt_tokens,
            is_whole_generation=True,
            single_line=single_line,
        )