Hook up use_default_badwordids in exllama

Use the value of the use_default_badwordids setting to configure
bad_words_ids: square-bracket tokens (and EOS) are only added to
bad_words_ids when use_default_badwordids is True. Also fix an issue
where the tokenizer was used before it was available, and fix an
exception raised while populating Lua bridge data when zero tokens
are generated, which can now happen when use_default_badwordids is
False and the first generated token is EOS.
Author: Llama
Date:   2023-08-29 23:08:51 -07:00
parent 36f53cc915
commit d6ed75f993
2 changed files with 13 additions and 18 deletions
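Before the diff, a minimal sketch of the badword selection this commit wires up (standalone Python; the helper name and the bracket-token IDs are illustrative, not from the codebase, but the logic mirrors the generation hunk below):

from types import SimpleNamespace

def build_bad_words_ids(tokenizer, bracket_tokens, newline_tokens,
                        use_default_badwordids, single_line):
    # Mirrors the diff: BOS is always banned; EOS and square-bracket
    # tokens are only banned when use_default_badwordids is on.
    bad_words_ids = [tokenizer.bos_token_id]
    if use_default_badwordids:
        bad_words_ids.append(tokenizer.eos_token_id)
        bad_words_ids.extend(bracket_tokens)
    if single_line:
        bad_words_ids.extend(newline_tokens)  # single-line mode also bans newlines
    return bad_words_ids

# Illustrative Llama-style IDs: BOS=1, EOS=2, "\n"=13; bracket IDs made up.
tok = SimpleNamespace(bos_token_id=1, eos_token_id=2)
assert build_bad_words_ids(tok, [518, 29961], [13], False, False) == [1]
assert build_bad_words_ids(tok, [518, 29961], [13], True, True) == [1, 2, 518, 29961, 13]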

@@ -3918,7 +3918,8 @@ def generate(txt, minimum, maximum, found_entries=None, gen_mode=GenerationMode.
         return
 
     for i in range(koboldai_vars.numseqs):
-        koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(genout[i, -1].item())
+        if len(genout[i]) > 0:
+            koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(genout[i, -1].item())
         koboldai_vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
 
     execute_outmod()
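The guard added above matters because, with EOS unbanned, a sequence can now finish with zero generated tokens. A small sketch of the failure mode it prevents (assuming genout is a per-sequence tensor of newly generated token IDs, as the surrounding code suggests):

import torch

# If the first sampled token is EOS, a sequence's row of generated
# tokens can be empty.
genout = torch.empty((1, 0), dtype=torch.long)

for i in range(genout.size(0)):
    # Without the length check, genout[i, -1] raises IndexError on an
    # empty row; the guard skips the Lua bridge update instead.
    if len(genout[i]) > 0:
        last_token = int(genout[i, -1].item())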

@@ -102,9 +102,6 @@ class model_backend(InferenceModel):
             post_token_probs=False,
         )
 
-        # We need to wait until the tokenizer is available to fill this in.
-        self.badwordsids = []
-
     def is_valid(self, model_name, model_path, menu_path):
         gptq_model, _ = load_model_gptq_settings(model_path)
         try:
@@ -129,7 +126,6 @@ class model_backend(InferenceModel):
         self.model = self._get_model(self.get_local_model_path(), {})
         self.tokenizer = self._get_tokenizer(self.get_local_model_path())
-        self.badwordsids = [self.tokenizer.bos_token_id, self.tokenizer.eos_token_id]
 
         self.cache = ExLlamaCache(self.model)
         self.generator = ExLlamaGenerator(self.model, self.tokenizer.tokenizer, self.cache)
 
@@ -221,6 +217,8 @@ class model_backend(InferenceModel):
         # Cache the newline token (for single line mode)
         # Since there is only one Llama token containing newline, just encode \n
         self.newline_tokens = self.tokenizer.encode("\n")
+        self.bracket_tokens = [i for i, tok in enumerate(vocab) if '[' in tok or ']' in tok]
+        self.tokenizer._koboldai_header = self.tokenizer.encode("")
 
     def unload(self):
         self.model_config = None
@@ -290,9 +288,12 @@ class model_backend(InferenceModel):
         if seed:
             torch.manual_seed(seed)
 
-        bad_words_ids = self.badwordsids
+        bad_words_ids = [self.tokenizer.bos_token_id]
+        if utils.koboldai_vars.use_default_badwordids:
+            bad_words_ids.append(self.tokenizer.eos_token_id)
+            bad_words_ids.extend(self.bracket_tokens)
         if single_line:
-            bad_words_ids = list(bad_words_ids) + self.newline_tokens
+            bad_words_ids.extend(self.newline_tokens)
 
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
@@ -301,7 +302,6 @@ class model_backend(InferenceModel):
 
         self.generator.gen_begin_reuse(gen_in)
 
-        trim_count = 0
         for i in range(max_new):
             logits = self.model.forward(self.generator.sequence[:, -1:], self.generator.cache)
             for bad_word_id in bad_words_ids:
@@ -322,16 +322,15 @@ class model_backend(InferenceModel):
                 if (scores.gather(1, token) > 0).all():
                     break
 
+            if (token == self.tokenizer.eos_token_id).any():
+                break
+
             self.generator.gen_accept_token(token)
 
             self._post_token_gen(self.generator.sequence)
 
             utils.koboldai_vars.generated_tkns += 1
 
-            if (token == self.tokenizer.eos_token_id).any():
-                trim_count = 1
-                break
-
             # Apply stoppers
             do_stop = False
             for stopper in self.stopper_hooks:
@@ -341,11 +340,7 @@ class model_backend(InferenceModel):
             if do_stop:
                 break
 
-        utils.koboldai_vars.generated_tkns = max_new - trim_count
-        if trim_count > 0:
-            seq = self.generator.sequence[:, gen_in.size(1):-trim_count]
-        else:
-            seq = self.generator.sequence[:, gen_in.size(1):]
+        seq = self.generator.sequence[:, gen_in.size(1):]
 
         return GenerationResult(
             model=self,
@@ -365,7 +360,6 @@ class model_backend(InferenceModel):
 
     def _get_tokenizer(self, location: str):
         tokenizer = GenericTokenizer(LlamaTokenizer.from_pretrained(location))
-        tokenizer._koboldai_header = tokenizer.encode("")
         return tokenizer
 
     def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):