From 2db1f2f7bb4dea89fb69aff93f4f1207f2974ace Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 25 Jan 2022 15:05:21 -0500 Subject: [PATCH] AvrilAI-style repetition penalty test --- aiserver.py | 5 ++--- tpu_mtj_backend.py | 45 ++++++++++++++++++++------------------------- warpers.py | 2 +- 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/aiserver.py b/aiserver.py index 64470d1e..5be7a17a 100644 --- a/aiserver.py +++ b/aiserver.py @@ -722,8 +722,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0) - RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__ - RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__ class LuaLogitsProcessor(LogitsProcessor): @@ -767,6 +765,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TemperatureLogitsWarper(temperature=0.5)) + warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor()) return warper_list def new_sample(self, *args, **kwargs): @@ -2771,7 +2770,7 @@ def _generate(txt, minimum, maximum, found_entries): do_sample=True, min_length=minimum, max_length=int(2e9), - repetition_penalty=1.1, + repetition_penalty=1.0, bad_words_ids=vars.badwordsids, use_cache=True, num_return_sequences=numseqs diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 653f8cf1..e7632eba 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -149,7 +149,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, 
generat logits[tokens] = penalty_logits return logits -def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): +def kobold_sample_dynamic(key, logits, rpargs, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): ''' This gets called by generate_loop_fn to apply a series of 4 filters to the logits (top-k, then top-p, then TFS, then temperature) before @@ -245,6 +245,7 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) + logits = apply_repetition_penalty_dynamic(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(np.uint32) def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): @@ -292,7 +293,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate # positions in the logits array return logits.at[tokens].set(penalty_logits) -def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): +def kobold_sample_static(key, logits, rpargs, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): ''' This gets called by generate_loop_fn to apply a series of 4 filters to the logits (top-k, then top-p, then TFS, then temperature) before @@ -387,6 +388,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) + logits = apply_repetition_penalty_static(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(jnp.uint32) pad_token_id = 50256 @@ -400,17 +402,6 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ # Get the pseudo-random number generator key that will # be used by kobold_sample_dynamic to randomly pick a token sample_key, new_key = 
jax.random.split(sample_key, num=2) - # Apply repetition penalty to all tokens that are - # currently inside the "generated" array - logits = apply_repetition_penalty_dynamic( - logits, - generated, - repetition_penalty, - generated_index, - gen_length, - rpslope, - rprange, - ) # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively # makes their probabilities of being chosen zero @@ -422,6 +413,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ next_token = kobold_sample_dynamic( sample_key, logits, + ( + generated, + repetition_penalty, + generated_index, + gen_length, + rpslope, + rprange, + ), **sampler_options, ) # Remember what token was picked @@ -493,18 +492,6 @@ class PenalizingCausalTransformer(CausalTransformer): assert logits.shape == (1, config["n_vocab"]) # Flatten it into a 1D array to make it easier to use logits = logits[0] - # Apply repetition penalty to all tokens that are - # currently inside the "generated" array - if repetition_penalty is not None: - logits = apply_repetition_penalty_static( - logits, - generated, - repetition_penalty, - generated_index, - gen_length, - rpslope, - rprange, - ) # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively # makes their probabilities of being chosen zero @@ -516,6 +503,14 @@ class PenalizingCausalTransformer(CausalTransformer): next_token = kobold_sample_static( sample_key, logits, + ( + generated, + repetition_penalty, + generated_index, + gen_length, + rpslope, + rprange, + ), **sampler_options, ) # Remember what token was picked diff --git a/warpers.py b/warpers.py index 07670f6d..122bc1cd 100644 --- a/warpers.py +++ b/warpers.py @@ -31,7 +31,7 @@ import torch from transformers import LogitsWarper, LogitsProcessor -class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor): +class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper): def __init__(self, 
*args, **kwargs): pass