Add the eos token to exllama bad words.

The bos token was already hardcoded as a bad word id.
Store the bad word ids in a list and iterate over them during generation.
Add the Llama eos token to the list of bad words.
Also support "single line mode", which adds the newline token (13) to the bad words.
commit 08ff7c138c
parent 0d150e412e
Author: Llama
Date:   2023-08-27 16:34:52 -07:00

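For orientation before the hunks: the technique throughout is plain logit suppression. Before each token is sampled, every banned id has its logit forced to a large negative value, so softmax assigns it effectively zero probability. A minimal, self-contained sketch of the idea (the [batch, seq, vocab] shape and the -10000.0 sentinel mirror the diff below; the function name and toy vocabulary are illustrative only):

import torch

def suppress_bad_words(logits: torch.Tensor, bad_words_ids: list[int]) -> torch.Tensor:
    """Force banned token ids to a huge negative logit so they are never sampled."""
    for bad_word_id in bad_words_ids:
        logits[:, :, bad_word_id] = -10000.0
    return logits

# Toy usage: vocabulary of 8 tokens, ban ids 1 (bos) and 2 (eos).
logits = torch.zeros(1, 1, 8)
suppress_bad_words(logits, [1, 2])
probs = torch.softmax(logits[0, -1, :], dim=-1)
assert probs[1] < 1e-4 and probs[2] < 1e-4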

@@ -95,6 +95,9 @@ class model_backend(InferenceModel):
             post_token_probs=False,
         )
 
+        # We need to wait until the tokenizer is available to fill this in.
+        self.badwordsids = []
+
     def is_valid(self, model_name, model_path, menu_path):
         gptq_model, _ = load_model_gptq_settings(model_path)
         try:
@@ -119,6 +122,7 @@ class model_backend(InferenceModel):
         self.model = self._get_model(self.get_local_model_path(), {})
         self.tokenizer = self._get_tokenizer(self.get_local_model_path())
+        self.badwordsids = [self.tokenizer.bos_token_id, self.tokenizer.eos_token_id]
 
         self.cache = ExLlamaCache(self.model)
         self.generator = ExLlamaGenerator(self.model, self.tokenizer.tokenizer, self.cache)
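The two-phase initialization exists because the ids live on the tokenizer, which is not available when the backend object is constructed. A hedged sketch of the pattern (the Backend class and the ids 1/2, typical for Llama-family models, are assumptions for illustration):

from types import SimpleNamespace

class Backend:
    def __init__(self):
        # The ids depend on the tokenizer, which is not loaded yet.
        self.badwordsids = []

    def load(self, tokenizer):
        # Filled in once the tokenizer exists; typically [1, 2] for Llama.
        self.badwordsids = [tokenizer.bos_token_id, tokenizer.eos_token_id]

backend = Backend()
backend.load(SimpleNamespace(bos_token_id=1, eos_token_id=2))
print(backend.badwordsids)  # [1, 2]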
@@ -207,6 +211,10 @@ class model_backend(InferenceModel):
             return result
 
         object.__setattr__(self.tokenizer, '__call__', call_wrapper.__get__(self.tokenizer))
 
+        # Cache the newline token (for single line mode)
+        # Since there is only one Llama token containing newline, just encode \n
+        self.newline_tokens = self.tokenizer.encode("\n")
+
     def unload(self):
         self.model_config = None
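A note on the cached value: in Llama's SentencePiece vocabulary the newline is the single byte piece <0x0A>, id 13, which matches the "(13)" in the commit message; caching tokenizer.encode("\n") at load time avoids hardcoding that id. A stub-tokenizer sketch (the stub and its return value are assumptions standing in for the real tokenizer):

class StubTokenizer:
    """Stands in for the real tokenizer; Llama maps a newline to id 13."""
    def encode(self, text: str) -> list[int]:
        return [13] if text == "\n" else []

# Cached once at load time, then reused for every single-line request.
newline_tokens = StubTokenizer().encode("\n")
print(newline_tokens)  # [13]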
@@ -275,6 +283,10 @@ class model_backend(InferenceModel):
         if seed:
             torch.manual_seed(seed)
 
+        bad_words_ids = self.badwordsids
+        if single_line:
+            bad_words_ids = list(bad_words_ids) + self.newline_tokens
+
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
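One detail in this hunk: list(bad_words_ids) + self.newline_tokens builds a new list instead of mutating the cached one, so self.badwordsids is not permanently polluted by a single-line request. A tiny sketch of why the copy matters (ids assumed to be the typical Llama values):

badwordsids = [1, 2]                 # persistent list: bos, eos
newline_tokens = [13]                # cached result of encoding "\n"

# What the diff does: copy, then extend the copy for this request only.
bad_words_ids = list(badwordsids) + newline_tokens

assert badwordsids == [1, 2]         # the cached list is unchanged
print(bad_words_ids)                 # [1, 2, 13]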
@@ -285,7 +297,8 @@ class model_backend(InferenceModel):
 
         trim_count = 0
         for i in range(max_new):
             logits = self.model.forward(self.generator.sequence[:, -1:], self.generator.cache)
-            logits[:, :, self.tokenizer.bos_token_id] = -10000.0
+            for bad_word_id in bad_words_ids:
+                logits[:, :, bad_word_id] = -10000.0
 
             logits = torch.unsqueeze(logits[0, -1, :], 0)
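The per-id loop is fine for two or three ids; for what it's worth, PyTorch advanced indexing can apply the same mask in a single assignment. A sketch of the equivalent vectorized form (an alternative formulation, not what the commit ships):

import torch

bad_words_ids = [1, 2, 13]            # bos, eos, newline (assumed Llama ids)
logits = torch.zeros(1, 1, 32000)     # [batch, seq, vocab]

# One fancy-indexing write replaces the per-id loop in the diff.
logits[:, :, bad_words_ids] = -10000.0

assert (logits[0, 0, bad_words_ids] == -10000.0).all()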