Model: And another refactor

This commit is contained in:
somebody
2023-03-01 19:16:35 -06:00
parent 225dcf1a0a
commit 54cecd4d5d
18 changed files with 3045 additions and 2911 deletions

View File

@@ -0,0 +1,98 @@
import torch
import requests
import numpy as np
from typing import List, Union
import utils
from modeling.inference_model import (
GenerationResult,
GenerationSettings,
InferenceModel,
)
class OpenAIAPIError(Exception):
def __init__(self, error_type: str, error_message) -> None:
super().__init__(f"{error_type}: {error_message}")
class OpenAIAPIInferenceModel(InferenceModel):
"""InferenceModel for interfacing with OpenAI's generation API."""
def _load(self, save_model: bool, initial_load: bool) -> None:
self.tokenizer = self._get_tokenizer("gpt2")
def _raw_generate(
self,
prompt_tokens: Union[List[int], torch.Tensor],
max_new: int,
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
) -> GenerationResult:
# Taken mainly from oairequest()
decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
# Store context in memory to use it for comparison with generated content
utils.koboldai_vars.lastctx = decoded_prompt
# Build request JSON data
# GooseAI is a subntype of OAI. So to check if it's this type, we check the configname as a workaround
# as the koboldai_vars.model will always be OAI
if "GooseAI" in utils.koboldai_vars.configname:
reqdata = {
"prompt": decoded_prompt,
"max_tokens": max_new,
"temperature": gen_settings.temp,
"top_a": gen_settings.top_a,
"top_p": gen_settings.top_p,
"top_k": gen_settings.top_k,
"tfs": gen_settings.tfs,
"typical_p": gen_settings.typical,
"repetition_penalty": gen_settings.rep_pen,
"repetition_penalty_slope": gen_settings.rep_pen_slope,
"repetition_penalty_range": gen_settings.rep_pen_range,
"n": batch_count,
# TODO: Implement streaming
"stream": False,
}
else:
reqdata = {
"prompt": decoded_prompt,
"max_tokens": max_new,
"temperature": gen_settings.temp,
"top_p": gen_settings.top_p,
"frequency_penalty": gen_settings.rep_pen,
"n": batch_count,
"stream": False,
}
req = requests.post(
utils.koboldai_vars.oaiurl,
json=reqdata,
headers={
"Authorization": "Bearer " + utils.koboldai_vars.oaiapikey,
"Content-Type": "application/json",
},
)
j = req.json()
if not req.ok:
# Send error message to web client
if "error" in j:
error_type = j["error"]["type"]
error_message = j["error"]["message"]
else:
error_type = "Unknown"
error_message = "Unknown"
raise OpenAIAPIError(error_type, error_message)
outputs = [out["text"] for out in j["choices"]]
return GenerationResult(
model=self,
out_batches=np.array([self.tokenizer.encode(x) for x in outputs]),
prompt=prompt_tokens,
is_whole_generation=True,
single_line=single_line,
)