Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Add the eos token to exllama bad words.
The bos token was already hardcoded as a bad word id. Store the bad word ids in a list and iterate over them during generation. Add the Llama eos token to the list of bad words. Also support "single line mode", which adds the newline token (id 13) to the bad words.
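For context, the approach in this patch is to suppress unwanted token ids by forcing their logits to a large negative value before the next token is picked. A minimal standalone sketch of that idea, not part of the patch itself; the ids 1, 2 and 13 are the usual Llama bos/eos/newline ids and the vocabulary size is an assumption:

    import torch

    # Assumed Llama token ids: 1 = bos, 2 = eos, 13 = "\n".
    badwordsids = [1, 2]
    newline_tokens = [13]
    single_line = True

    bad_words_ids = badwordsids
    if single_line:
        # Copy before extending so the stored list is not mutated.
        bad_words_ids = list(bad_words_ids) + newline_tokens

    vocab_size = 32000                      # assumed Llama vocab size
    logits = torch.randn(1, 1, vocab_size)  # stand-in for one step of model output

    # A large negative logit makes these tokens effectively unselectable.
    for bad_word_id in bad_words_ids:
        logits[:, :, bad_word_id] = -10000.0

    next_token = torch.argmax(logits[0, -1, :]).item()
    assert next_token not in bad_words_ids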
@@ -95,6 +95,9 @@ class model_backend(InferenceModel):
             post_token_probs=False,
         )
 
+        # We need to wait until the tokenizer is available to fill this in.
+        self.badwordsids = []
+
     def is_valid(self, model_name, model_path, menu_path):
         gptq_model, _ = load_model_gptq_settings(model_path)
         try:
@@ -119,6 +122,7 @@ class model_backend(InferenceModel):
         self.model = self._get_model(self.get_local_model_path(), {})
         self.tokenizer = self._get_tokenizer(self.get_local_model_path())
 
+        self.badwordsids = [self.tokenizer.bos_token_id, self.tokenizer.eos_token_id]
         self.cache = ExLlamaCache(self.model)
 
         self.generator = ExLlamaGenerator(self.model, self.tokenizer.tokenizer, self.cache)
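Once the tokenizer is loaded, the bos and eos ids can be read straight from it. A hedged sketch of the same lookup using the Hugging Face tokenizer API; the model path is a placeholder, and the ExLlama backend's wrapped tokenizer is assumed to expose the same attributes:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("/path/to/llama-model")  # placeholder path
    badwordsids = [tokenizer.bos_token_id, tokenizer.eos_token_id]
    print(badwordsids)  # for Llama-family models this is typically [1, 2]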
@@ -207,6 +211,10 @@ class model_backend(InferenceModel):
             return result
         object.__setattr__(self.tokenizer, '__call__', call_wrapper.__get__(self.tokenizer))
 
+        # Cache the newline token (for single line mode)
+        # Since there is only one Llama token containing newline, just encode \n
+        self.newline_tokens = self.tokenizer.encode("\n")
+
     def unload(self):
         self.model_config = None
 
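The newline ids are cached once at load time so generation does not have to re-encode "\n" on every call. A hedged sketch of the same idea with a Hugging Face tokenizer; exact output varies by tokenizer, but for Llama-family vocabularies the newline token itself is id 13 (some tokenizers also emit a prefix-space id before it):

    # Continuing the sketch above: cache whatever ids "\n" encodes to.
    newline_tokens = tokenizer.encode("\n", add_special_tokens=False)
    # For Llama-family vocabularies the list ends with the newline id 13.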
@@ -275,6 +283,10 @@ class model_backend(InferenceModel):
         if seed:
             torch.manual_seed(seed)
 
+        bad_words_ids = self.badwordsids
+        if single_line:
+            bad_words_ids = list(bad_words_ids) + self.newline_tokens
+
         if not isinstance(prompt_tokens, torch.Tensor):
            gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
        else:
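One detail worth noting in this hunk: list(bad_words_ids) + self.newline_tokens builds a fresh list, so a single-line request does not permanently add the newline id to self.badwordsids. A tiny illustrative snippet (the ids are assumptions):

    badwordsids = [1, 2]        # assumed bos/eos ids
    newline_tokens = [13]       # assumed newline id

    bad_words_ids = list(badwordsids) + newline_tokens  # copy, then extend
    assert badwordsids == [1, 2]          # stored list is unchanged
    assert bad_words_ids == [1, 2, 13]    # per-request list includes the newline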
@@ -285,7 +297,8 @@ class model_backend(InferenceModel):
         trim_count = 0
         for i in range(max_new):
             logits = self.model.forward(self.generator.sequence[:, -1:], self.generator.cache)
-            logits[:, :, self.tokenizer.bos_token_id] = -10000.0
+            for bad_word_id in bad_words_ids:
+                logits[:, :, bad_word_id] = -10000.0
 
             logits = torch.unsqueeze(logits[0, -1, :], 0)
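After masking, the banned ids carry a logit of -10000, so softmax assigns them essentially zero probability and neither greedy nor stochastic sampling can pick them. A short standalone check of that property (shapes and ids are assumptions, as in the earlier sketch):

    import torch

    bad_words_ids = [1, 2, 13]              # assumed bos/eos/newline ids
    logits = torch.randn(1, 1, 32000)       # stand-in for one decoding step

    for bad_word_id in bad_words_ids:
        logits[:, :, bad_word_id] = -10000.0

    probs = torch.softmax(logits[0, -1, :], dim=-1)
    samples = torch.multinomial(probs, num_samples=100, replacement=True)
    assert not any(t.item() in bad_words_ids for t in samples)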