mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
79 lines
2.3 KiB
Python
79 lines
2.3 KiB
Python
from __future__ import annotations
|
|
|
|
import torch
|
|
import requests
|
|
import numpy as np
|
|
from typing import List, Union
|
|
|
|
import utils
|
|
|
|
from modeling.inference_model import (
|
|
GenerationResult,
|
|
GenerationSettings,
|
|
InferenceModel,
|
|
)
|
|
|
|
|
|
class ColabException(Exception):
    """Error raised when the Colab API interface cannot be used successfully.

    This is the exception for failures while talking to the Colab endpoint,
    for example when the server answers with a non-200 HTTP status.
    """
|
|
|
|
|
|
class ColabInferenceModel(InferenceModel):
    """Inference model that delegates generation to a remote Colab HTTP endpoint.

    Only a tokenizer is loaded locally. Each generation request decodes the
    prompt tokens to text and POSTs them to ``utils.koboldai_vars.colaburl``.
    The text that comes back is re-encoded into token IDs.
    """

    # Seconds to wait for the Colab server; forwarded to ``requests.post``.
    # ``None`` (the default) waits indefinitely, which is the historical
    # behaviour. Set a number so that a dead or unreachable server raises
    # instead of hanging generation forever.
    request_timeout: Union[float, None] = None

    def _load(self, save_model: bool, initial_load: bool) -> None:
        """Load the tokenizer used to decode prompts and re-encode replies.

        Args:
            save_model: Accepted for interface compatibility; unused here.
            initial_load: Accepted for interface compatibility; unused here.
        """
        # NOTE(review): the tokenizer is hard-coded to GPT-Neo 2.7B regardless
        # of which model the Colab server hosts -- confirm it matches the
        # remote model's vocabulary.
        self.tokenizer = self._get_tokenizer("EleutherAI/gpt-neo-2.7B")

    @staticmethod
    def _to_token_array(token_lists: List[List[int]]) -> np.ndarray:
        """Pack per-sequence token lists into a NumPy array.

        Sequences that all have the same length produce a regular 2-D integer
        array (``np.array`` of the lists). If the lengths differ, the result
        is a 1-D ``object`` array holding one token list per element.

        NumPy >= 1.24 refuses to build a ragged array implicitly: it raises
        ``ValueError`` instead of returning an object array. Building the
        object array explicitly keeps batched generations with different
        output lengths working.
        """
        if len({len(tokens) for tokens in token_lists}) <= 1:
            return np.array(token_lists)

        ragged = np.empty(len(token_lists), dtype=object)
        for index, tokens in enumerate(token_lists):
            # Scalar-index assignment stores the list itself as the element.
            ragged[index] = tokens
        return ragged

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        **kwargs,
    ) -> GenerationResult:
        """Generate text on the remote Colab server.

        Args:
            prompt_tokens: Prompt token IDs; decoded with ``self.tokenizer``.
            max_new: Sent to the server as the ``"max"`` field.
            gen_settings: Sampling settings copied into the request body.
            single_line: Forwarded unchanged to ``GenerationResult``.
            batch_count: Requested number of sequences (``"numseqs"``).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A ``GenerationResult`` whose ``out_batches`` holds the re-encoded
            token IDs of each returned sequence.

        Raises:
            ColabException: If the server responds with a non-200 status.
            KeyError: If the JSON response lacks ``"data"``, or ``"data"``
                lacks both ``"text"`` and ``"seqs"``.
            requests.RequestException: On connection failure or when
                ``request_timeout`` is exceeded.

        Side effects:
            Sets ``utils.koboldai_vars.lastctx`` to the decoded prompt.
        """
        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

        # Store context in memory to use it for comparison with generated content
        utils.koboldai_vars.lastctx = decoded_prompt

        # Build request JSON data. The key names (including "topa" and
        # "retfultxt") are the Colab server's wire format and must not change.
        # NOTE(review): "max" receives the number of *new* tokens and "min" is
        # 0 -- confirm the Colab endpoint does not expect total lengths here.
        reqdata = {
            "text": decoded_prompt,
            "min": 0,
            "max": max_new,
            "rep_pen": gen_settings.rep_pen,
            "rep_pen_slope": gen_settings.rep_pen_slope,
            "rep_pen_range": gen_settings.rep_pen_range,
            "temperature": gen_settings.temp,
            "top_p": gen_settings.top_p,
            "top_k": gen_settings.top_k,
            "tfs": gen_settings.tfs,
            "typical": gen_settings.typical,
            "topa": gen_settings.top_a,
            "numseqs": batch_count,
            "retfultxt": False,
        }

        # Create request
        req = requests.post(
            utils.koboldai_vars.colaburl,
            json=reqdata,
            timeout=self.request_timeout,
        )

        if req.status_code != 200:
            raise ColabException(f"Bad status code {req.status_code}")

        # Deal with the response
        js = req.json()["data"]

        # Try to be backwards compatible with outdated colab: older servers
        # return a single "text" field (presumably prompt + continuation, which
        # getnewcontent trims using lastctx -- verify), newer ones a "seqs" list.
        if "text" in js:
            genout = [utils.getnewcontent(js["text"], self.tokenizer)]
        else:
            genout = js["seqs"]

        return GenerationResult(
            model=self,
            out_batches=self._to_token_array(
                [self.tokenizer.encode(x) for x in genout]
            ),
            prompt=prompt_tokens,
            is_whole_generation=True,
            single_line=single_line,
        )
|