AvrilAI-style repetition penalty test

This commit is contained in:
Gnome Ann
2022-01-25 15:05:21 -05:00
parent 9356573ac9
commit 2db1f2f7bb
3 changed files with 23 additions and 29 deletions

View File

@ -722,8 +722,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
class LuaLogitsProcessor(LogitsProcessor):
@ -767,6 +765,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
return warper_list
def new_sample(self, *args, **kwargs):
@ -2771,7 +2770,7 @@ def _generate(txt, minimum, maximum, found_entries):
do_sample=True,
min_length=minimum,
max_length=int(2e9),
repetition_penalty=1.1,
repetition_penalty=1.0,
bad_words_ids=vars.badwordsids,
use_cache=True,
num_return_sequences=numseqs