Allow EOS unbanning

Henk
2023-08-29 20:51:09 +02:00
parent d77acf17eb
commit 49fa63052f
4 changed files with 52 additions and 12 deletions


@@ -330,19 +330,39 @@ class HFTorchInferenceModel(HFInferenceModel):
         if seed is not None:
             torch.manual_seed(seed)
 
+        if utils.koboldai_vars.use_default_badwordids:
+            self.active_badwordids = self.badwordsids + additional_bad_words_ids
+        else:
+            if additional_bad_words_ids:
+                self.active_badwordids = additional_bad_words_ids
+            else:
+                self.active_badwordids = None
+
         with torch.no_grad():
             start_time = time.time()
-            genout = self.model.generate(
-                input_ids=gen_in,
-                do_sample=True,
-                max_length=min(
-                    len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
-                ),
-                repetition_penalty=1.0,
-                bad_words_ids=self.badwordsids + additional_bad_words_ids,
-                use_cache=True,
-                num_return_sequences=batch_count,
-            )
+            if self.active_badwordids: ## I know duplicating this is ugly, but HF checks if its present and accepts nothing but actual token bans if its there (Which I can't guarantee would be universal enough).... - Henk
+                genout = self.model.generate(
+                    input_ids=gen_in,
+                    do_sample=True,
+                    max_length=min(
+                        len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
+                    ),
+                    repetition_penalty=1.0,
+                    bad_words_ids=self.active_badwordids,
+                    use_cache=True,
+                    num_return_sequences=batch_count,
+                )
+            else:
+                genout = self.model.generate(
+                    input_ids=gen_in,
+                    do_sample=True,
+                    max_length=min(
+                        len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
+                    ),
+                    repetition_penalty=1.0,
+                    use_cache=True,
+                    num_return_sequences=batch_count,
+                )
         logger.debug(
             "torch_raw_generate: run generator {}s".format(time.time() - start_time)
         )
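
Henk's inline comment explains why the generate() call is duplicated instead of always passing bad_words_ids: once the argument is supplied, Hugging Face expects an actual list of banned token ids, so the parameter has to be left out entirely when no bans are active. As a rough illustration only (not part of this commit, and assuming the same surrounding variables gen_in, prompt_tokens, max_new, batch_count and self.active_badwordids are in scope), the same selection could also be expressed by building the keyword arguments once and attaching bad_words_ids only when a ban list exists:

# Hypothetical sketch, not taken from the commit: assemble the generate()
# arguments once and only include bad_words_ids when there are real token bans.
gen_kwargs = dict(
    input_ids=gen_in,
    do_sample=True,
    max_length=min(len(prompt_tokens) + max_new, utils.koboldai_vars.max_length),
    repetition_penalty=1.0,
    use_cache=True,
    num_return_sequences=batch_count,
)
if self.active_badwordids:
    # Forward the ban list only when it actually holds banned token ids.
    gen_kwargs["bad_words_ids"] = self.active_badwordids
genout = self.model.generate(**gen_kwargs)

Whether that behaves identically across every Hugging Face model class is exactly the guarantee the comment says it could not make, which is why the commit keeps the explicit duplication.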