From 08ff7c138c35e344819acbd82fa18e88732e08a4 Mon Sep 17 00:00:00 2001
From: Llama <34464159+pi6am@users.noreply.github.com>
Date: Sun, 27 Aug 2023 16:34:52 -0700
Subject: [PATCH] Add the eos token to exllama bad words.

The bos token was already hardcoded as a bad word id. Store badwords
in a list and iterate over them during generation. Add the Llama eos
token to the list of bad words.

Also support "single line mode", which adds newline (13) to badwords.
---
 modeling/inference_models/exllama/class.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/modeling/inference_models/exllama/class.py b/modeling/inference_models/exllama/class.py
index 3fb8d252..737afa88 100644
--- a/modeling/inference_models/exllama/class.py
+++ b/modeling/inference_models/exllama/class.py
@@ -95,6 +95,9 @@ class model_backend(InferenceModel):
             post_token_probs=False,
         )
 
+        # We need to wait until the tokenizer is available to fill this in.
+        self.badwordsids = []
+
     def is_valid(self, model_name, model_path, menu_path):
         gptq_model, _ = load_model_gptq_settings(model_path)
         try:
@@ -119,6 +122,7 @@ class model_backend(InferenceModel):
         self.model = self._get_model(self.get_local_model_path(), {})
         self.tokenizer = self._get_tokenizer(self.get_local_model_path())
+        self.badwordsids = [self.tokenizer.bos_token_id, self.tokenizer.eos_token_id]
 
         self.cache = ExLlamaCache(self.model)
 
         self.generator = ExLlamaGenerator(self.model, self.tokenizer.tokenizer, self.cache)
@@ -207,6 +211,10 @@ class model_backend(InferenceModel):
             return result
         object.__setattr__(self.tokenizer, '__call__', call_wrapper.__get__(self.tokenizer))
 
+        # Cache the newline token (for single line mode)
+        # Since there is only one Llama token containing newline, just encode \n
+        self.newline_tokens = self.tokenizer.encode("\n")
+
     def unload(self):
         self.model_config = None
 
@@ -275,6 +283,10 @@ class model_backend(InferenceModel):
         if seed:
             torch.manual_seed(seed)
 
+        bad_words_ids = self.badwordsids
+        if single_line:
+            bad_words_ids = list(bad_words_ids) + self.newline_tokens
+
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
@@ -285,7 +297,8 @@ class model_backend(InferenceModel):
         trim_count = 0
         for i in range(max_new):
             logits = self.model.forward(self.generator.sequence[:, -1:], self.generator.cache)
-            logits[:, :, self.tokenizer.bos_token_id] = -10000.0
+            for bad_word_id in bad_words_ids:
+                logits[:, :, bad_word_id] = -10000.0
 
             logits = torch.unsqueeze(logits[0, -1, :], 0)
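
Note (not part of the patch): the change works by masking banned token ids
directly in the logits before sampling. The following standalone Python sketch
illustrates that approach with toy values; the token ids used here (1 = bos,
2 = eos, 13 = newline) follow Llama conventions, and the tensor is a stand-in
for the model's real output, not the ExLlama generator itself.

import torch

def mask_bad_words(logits: torch.Tensor, bad_words_ids) -> torch.Tensor:
    # Push every banned id to a very low logit so sampling never selects it.
    for bad_word_id in bad_words_ids:
        logits[:, :, bad_word_id] = -10000.0
    return logits

badwordsids = [1, 2]      # bos and eos ids, filled in once the tokenizer exists
newline_tokens = [13]     # Llama encodes "\n" as a single token
single_line = True        # toggling this adds the newline token to the ban list

bad_words_ids = badwordsids
if single_line:
    bad_words_ids = list(bad_words_ids) + newline_tokens

logits = torch.randn(1, 1, 32)   # toy (batch, seq, vocab) logits
logits = mask_bad_words(logits, bad_words_ids)
assert logits[0, 0, 13].item() == -10000.0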