from __future__ import annotations

import torch
import requests
import numpy as np
from typing import List, Union

import utils
from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    InferenceModel,
)


class ColabException(Exception):
    """To be used for errors when using the Colab API as an interface."""


class ColabInferenceModel(InferenceModel):
    def _load(self, save_model: bool, initial_load: bool) -> None:
        self.tokenizer = self._get_tokenizer("EleutherAI/gpt-neo-2.7B")

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        **kwargs,
    ):
        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

        # Store context in memory to use it for comparison with generated content
        utils.koboldai_vars.lastctx = decoded_prompt

        # Build request JSON data
        reqdata = {
            "text": decoded_prompt,
            "min": 0,
            "max": max_new,
            "rep_pen": gen_settings.rep_pen,
            "rep_pen_slope": gen_settings.rep_pen_slope,
            "rep_pen_range": gen_settings.rep_pen_range,
            "temperature": gen_settings.temp,
            "top_p": gen_settings.top_p,
            "top_k": gen_settings.top_k,
            "tfs": gen_settings.tfs,
            "typical": gen_settings.typical,
            "topa": gen_settings.top_a,
            "numseqs": batch_count,
            "retfultxt": False,
        }

        # Create request
        req = requests.post(utils.koboldai_vars.colaburl, json=reqdata)

        if req.status_code != 200:
            raise ColabException(f"Bad status code {req.status_code}")

        # Deal with the response
        js = req.json()["data"]

        # Try to be backwards compatible with outdated colab
        if "text" in js:
            genout = [utils.getnewcontent(js["text"], self.tokenizer)]
        else:
            genout = js["seqs"]

        return GenerationResult(
            model=self,
            out_batches=np.array([self.tokenizer.encode(x) for x in genout]),
            prompt=prompt_tokens,
            is_whole_generation=True,
            single_line=single_line,
        )
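
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the module). It assumes
# ColabInferenceModel can be constructed without arguments and that
# GenerationSettings exposes the sampler fields referenced above; check the
# actual KoboldAI code for the real construction and loading entry points.
#
#   utils.koboldai_vars.colaburl = "https://example.trycloudflare.com/request"
#   model = ColabInferenceModel()
#   model._load(save_model=False, initial_load=True)
#   settings = GenerationSettings(...)  # temp, top_p, rep_pen, etc.
#   prompt_tokens = model.tokenizer.encode("Once upon a time")
#   result = model._raw_generate(prompt_tokens, max_new=80, gen_settings=settings)
#   print(model.tokenizer.decode(result.out_batches[0]))
# ---------------------------------------------------------------------------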