From e874f0c1c26501a0c2592b3acde8a3a271a7c50d Mon Sep 17 00:00:00 2001
From: 0cc4m
Date: Mon, 19 Jun 2023 19:05:31 +0200
Subject: [PATCH] Add token streaming support for exllama

---
 modeling/inference_models/exllama/class.py | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/modeling/inference_models/exllama/class.py b/modeling/inference_models/exllama/class.py
index 37681b4f..614a3de1 100644
--- a/modeling/inference_models/exllama/class.py
+++ b/modeling/inference_models/exllama/class.py
@@ -75,6 +75,25 @@ class model_backend(InferenceModel):
         self.model_name = ""
         self.path = None
 
+        self.post_token_hooks = [
+            PostTokenHooks.stream_tokens,
+        ]
+
+        self.stopper_hooks = [
+            Stoppers.core_stopper,
+            Stoppers.dynamic_wi_scanner,
+            Stoppers.singleline_stopper,
+            Stoppers.chat_mode_stopper,
+            Stoppers.stop_sequence_stopper,
+        ]
+
+        self.capabilties = ModelCapabilities(
+            embedding_manipulation=False,
+            post_token_hooks=True,
+            stopper_hooks=False,
+            post_token_probs=False,
+        )
+
     def is_valid(self, model_name, model_path, menu_path):
         gptq_model, _ = load_model_gptq_settings(model_path)
         try:
@@ -265,11 +284,8 @@ class model_backend(InferenceModel):
 
         self.generator.gen_begin(gen_in)
 
-        # from pudb.remote import set_trace
-        # set_trace(term_size=(200, 60))
-
         for i in range(max_new):
-            logits = self.model.forward(self.generator.sequence[:, -1:], self.cache)
+            logits = self.model.forward(self.generator.sequence[:, -1:], self.generator.cache)
             logits[:, :, self.tokenizer.bos_token_id] = -10000.0
 
             logits = torch.unsqueeze(logits[0, -1, :], 0)
@@ -282,6 +298,8 @@ class model_backend(InferenceModel):
             self.generator.gen_accept_token(token)
 
+            self._post_token_gen(self.generator.sequence)
+
             if token.item() == self.tokenizer.eos_token_id:
                 break
 
         return GenerationResult(
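
For illustration, a minimal, self-contained sketch of the pattern the patch introduces: after each sampled token is accepted, every registered post-token hook is invoked with the sequence generated so far (the role self._post_token_gen(self.generator.sequence) plays above), so a frontend can stream partial output as it is produced. The names ToyModel, stream_to_console, and generate are hypothetical stand-ins, not KoboldAI or exllama APIs; only the hook-after-each-token structure mirrors the diff.

# Sketch of per-token streaming via post-token hooks.
# ToyModel is a stand-in producing random logits; real backends would run
# a forward pass of the actual model here.
from typing import Callable, List

import torch


class ToyModel:
    """Illustrative stand-in for the real model."""

    vocab_size = 32
    eos_token_id = 0

    def forward(self, last_token: torch.Tensor) -> torch.Tensor:
        # Return fake logits of shape (1, vocab_size).
        return torch.rand(1, self.vocab_size)


def stream_to_console(sequence: torch.Tensor) -> None:
    """Example post-token hook: show the token ids generated so far."""
    print("partial sequence:", sequence.tolist()[0])


def generate(model: ToyModel,
             prompt_ids: List[int],
             max_new: int,
             post_token_hooks: List[Callable[[torch.Tensor], None]]) -> torch.Tensor:
    sequence = torch.tensor([prompt_ids], dtype=torch.long)

    for _ in range(max_new):
        logits = model.forward(sequence[:, -1:])
        # Pick the next token (greedy for simplicity) and append it.
        token = torch.argmax(logits, dim=-1)
        sequence = torch.cat([sequence, token.view(1, 1)], dim=-1)

        # The streaming step: notify every hook after each accepted token,
        # analogous to the self._post_token_gen(...) call added in the patch.
        for hook in post_token_hooks:
            hook(sequence)

        if token.item() == model.eos_token_id:
            break

    return sequence


if __name__ == "__main__":
    generate(ToyModel(), prompt_ids=[1, 2, 3], max_new=5,
             post_token_hooks=[stream_to_console])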