AvrilAI-style repetition penalty test

This commit is contained in:
Gnome Ann 2022-01-25 15:05:21 -05:00
parent 9356573ac9
commit 2db1f2f7bb
3 changed files with 23 additions and 29 deletions

View File

@ -722,8 +722,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
class LuaLogitsProcessor(LogitsProcessor):
@ -767,6 +765,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
return warper_list
def new_sample(self, *args, **kwargs):
@ -2771,7 +2770,7 @@ def _generate(txt, minimum, maximum, found_entries):
do_sample=True,
min_length=minimum,
max_length=int(2e9),
repetition_penalty=1.1,
repetition_penalty=1.0,
bad_words_ids=vars.badwordsids,
use_cache=True,
num_return_sequences=numseqs

View File

@ -149,7 +149,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
logits[tokens] = penalty_logits
return logits
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
def kobold_sample_dynamic(key, logits, rpargs, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
'''
This gets called by generate_loop_fn to apply a series of 4 filters
to the logits (top-k, then top-p, then TFS, then temperature) before
@ -245,6 +245,7 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
# Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a
# probability distribution)
logits = apply_repetition_penalty_dynamic(logits, *rpargs)
return jax.random.categorical(key, logits, -1).astype(np.uint32)
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
@ -292,7 +293,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
# positions in the logits array
return logits.at[tokens].set(penalty_logits)
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
def kobold_sample_static(key, logits, rpargs, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
'''
This gets called by generate_loop_fn to apply a series of 4 filters
to the logits (top-k, then top-p, then TFS, then temperature) before
@ -387,6 +388,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
# Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a
# probability distribution)
logits = apply_repetition_penalty_static(logits, *rpargs)
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
pad_token_id = 50256
@ -400,17 +402,6 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
# Get the pseudo-random number generator key that will
# be used by kobold_sample_dynamic to randomly pick a token
sample_key, new_key = jax.random.split(sample_key, num=2)
# Apply repetition penalty to all tokens that are
# currently inside the "generated" array
logits = apply_repetition_penalty_dynamic(
logits,
generated,
repetition_penalty,
generated_index,
gen_length,
rpslope,
rprange,
)
# Remove any tokens in the badwords list by setting
# their logits to negative infinity which effectively
# makes their probabilities of being chosen zero
@ -422,6 +413,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
next_token = kobold_sample_dynamic(
sample_key,
logits,
(
generated,
repetition_penalty,
generated_index,
gen_length,
rpslope,
rprange,
)
**sampler_options,
)
# Remember what token was picked
@ -493,18 +492,6 @@ class PenalizingCausalTransformer(CausalTransformer):
assert logits.shape == (1, config["n_vocab"])
# Flatten it into a 1D array to make it easier to use
logits = logits[0]
# Apply repetition penalty to all tokens that are
# currently inside the "generated" array
if repetition_penalty is not None:
logits = apply_repetition_penalty_static(
logits,
generated,
repetition_penalty,
generated_index,
gen_length,
rpslope,
rprange,
)
# Remove any tokens in the badwords list by setting
# their logits to negative infinity which effectively
# makes their probabilities of being chosen zero
@ -516,6 +503,14 @@ class PenalizingCausalTransformer(CausalTransformer):
next_token = kobold_sample_static(
sample_key,
logits,
(
generated,
repetition_penalty,
generated_index,
gen_length,
rpslope,
rprange,
),
**sampler_options,
)
# Remember what token was picked

View File

@ -31,7 +31,7 @@ import torch
from transformers import LogitsWarper, LogitsProcessor
class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor):
class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper):
def __init__(self, *args, **kwargs):
pass