From 39dfb1845570718d31490273bcb008718419b54e Mon Sep 17 00:00:00 2001
From: 0cc4m
Date: Tue, 6 Jun 2023 19:21:34 +0200
Subject: [PATCH] Replace exllama samplers with kobold's inbuilt ones

---
 modeling/inference_models/exllama/class.py | 56 +++++++++++++++++++++-
 1 file changed, 54 insertions(+), 2 deletions(-)

diff --git a/modeling/inference_models/exllama/class.py b/modeling/inference_models/exllama/class.py
index db1728cf..3ff38d33 100644
--- a/modeling/inference_models/exllama/class.py
+++ b/modeling/inference_models/exllama/class.py
@@ -15,6 +15,10 @@ import gc
 import utils
 
 from logger import logger
+from modeling import warpers
+from modeling.warpers import Warper
+from modeling.stoppers import Stoppers
+from modeling.post_token_hooks import PostTokenHooks
 from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
@@ -30,6 +34,11 @@ from exllama.generator import ExLlamaGenerator
 
 model_backend_name = "ExLlama"
 
+# When set to true, messages will appear in the console if samplers are not
+# changing the scores. Keep in mind some samplers don't always change the
+# scores for each token.
+LOG_SAMPLER_NO_EFFECT = False
+
 
 def load_model_gptq_settings(path):
     try:
@@ -86,7 +95,7 @@ class model_backend(InferenceModel):
 
     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.model = self._get_model(self.get_local_model_path(), {})
-        self.tokenizer = self._get_tokenizer(self.get_local_model_path()))
+        self.tokenizer = self._get_tokenizer(self.get_local_model_path())
 
         self.cache = ExLlamaCache(self.model)
 
@@ -203,6 +212,34 @@ class model_backend(InferenceModel):
         except:
             pass
 
+    def _apply_warpers(
+        self, scores: torch.Tensor, input_ids: torch.Tensor
+    ) -> torch.Tensor:
+        warpers.update_settings()
+
+        if LOG_SAMPLER_NO_EFFECT:
+            pre = torch.Tensor(scores)
+
+        for sid in utils.koboldai_vars.sampler_order:
+            warper = Warper.from_id(sid)
+
+            if not warper.value_is_valid():
+                continue
+
+            if warper == warpers.RepetitionPenalty:
+                # Rep pen needs more data than other samplers
+                scores = warper.torch(scores, input_ids=input_ids)
+            else:
+                scores = warper.torch(scores)
+
+            assert scores is not None, f"Scores are None; warper '{warper}' is to blame"
+
+            if LOG_SAMPLER_NO_EFFECT:
+                if torch.equal(pre, scores):
+                    logger.info(warper, "had no effect on the scores.")
+                pre = torch.Tensor(scores)
+        return scores
+
     def _raw_generate(
         self,
         prompt_tokens: Union[List[int], torch.Tensor],
@@ -228,8 +265,23 @@
         self.generator.gen_begin(gen_in)
 
+        # from pudb.remote import set_trace
+        # set_trace(term_size=(200, 60))
+
         for i in range(max_new):
-            token = self.generator.gen_single_token()
+            logits = self.model.forward(self.generator.sequence[:, -1:], self.cache)
+            logits[:, :, self.tokenizer.bos_token_id] = -10000.0
+
+            logits = torch.unsqueeze(logits[0, -1, :], 0)
+
+            scores = self._apply_warpers(logits, gen_in)
+
+            scores = torch.softmax(scores, dim=-1)
+
+            token = torch.multinomial(scores, 1)
+
+            self.generator.gen_accept_token(token)
+
             if token.item() == self.tokenizer.eos_token_id:
                 break
 
         return GenerationResult(
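
The last hunk replaces ExLlama's built-in gen_single_token() with KoboldAI's warper stack followed by plain multinomial sampling. Below is a minimal, self-contained sketch of that sampling step; the warper stack is stubbed out with a simple temperature warper, and sample_next_token and its default temperature are hypothetical names used only for illustration, not part of the patch or of KoboldAI's API:

    import torch

    def sample_next_token(logits: torch.Tensor, temperature: float = 0.8) -> int:
        # Stand-in for _apply_warpers(): any score-shaping sampler goes here.
        scores = logits / temperature
        # Normalize the warped scores into a probability distribution.
        probs = torch.softmax(scores, dim=-1)
        # Draw a single token id from that distribution, as the patch does
        # with torch.multinomial(scores, 1).
        return torch.multinomial(probs, 1).item()

    # Toy example: logits for a five-token vocabulary, shaped [1, vocab_size]
    # like the unsqueezed logits in the patch.
    print(sample_next_token(torch.tensor([[0.1, 2.0, -1.0, 0.5, 0.0]])))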