mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Allow EOS unbanning
This commit is contained in:
@@ -330,19 +330,39 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if utils.koboldai_vars.use_default_badwordids:
|
||||
self.active_badwordids = self.badwordsids + additional_bad_words_ids
|
||||
else:
|
||||
if additional_bad_words_ids:
|
||||
self.active_badwordids = additional_bad_words_ids
|
||||
else:
|
||||
self.active_badwordids = None
|
||||
|
||||
with torch.no_grad():
|
||||
start_time = time.time()
|
||||
genout = self.model.generate(
|
||||
input_ids=gen_in,
|
||||
do_sample=True,
|
||||
max_length=min(
|
||||
len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
|
||||
),
|
||||
repetition_penalty=1.0,
|
||||
bad_words_ids=self.badwordsids + additional_bad_words_ids,
|
||||
use_cache=True,
|
||||
num_return_sequences=batch_count,
|
||||
)
|
||||
if self.active_badwordids: ## I know duplicating this is ugly, but HF checks if it's present and accepts nothing but actual token bans if it's there (which I can't guarantee would be universal enough). - Henk
|
||||
genout = self.model.generate(
|
||||
input_ids=gen_in,
|
||||
do_sample=True,
|
||||
max_length=min(
|
||||
len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
|
||||
),
|
||||
repetition_penalty=1.0,
|
||||
bad_words_ids=self.active_badwordids,
|
||||
use_cache=True,
|
||||
num_return_sequences=batch_count,
|
||||
)
|
||||
else:
|
||||
genout = self.model.generate(
|
||||
input_ids=gen_in,
|
||||
do_sample=True,
|
||||
max_length=min(
|
||||
len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
|
||||
),
|
||||
repetition_penalty=1.0,
|
||||
use_cache=True,
|
||||
num_return_sequences=batch_count,
|
||||
)
|
||||
logger.debug(
|
||||
"torch_raw_generate: run generator {}s".format(time.time() - start_time)
|
||||
)
|
||||
|
Reference in New Issue
Block a user