Merge branch 'avril' into rep-pen-order

This commit is contained in:
vfbd
2022-08-23 14:47:29 -04:00
3 changed files with 23 additions and 29 deletions

View File

@ -1727,8 +1727,6 @@ def patch_transformers():
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
class LuaLogitsProcessor(LogitsProcessor):
@ -1805,6 +1803,7 @@ def patch_transformers():
self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
for k in vars.sampler_order:
@ -4617,7 +4616,7 @@ def _generate(txt, minimum, maximum, found_entries):
gen_in,
do_sample=True,
max_length=int(2e9),
repetition_penalty=1.1,
repetition_penalty=1.0,
bad_words_ids=vars.badwordsids,
use_cache=True,
num_return_sequences=numseqs