Repetition penalty slope and range

This commit is contained in:
Gnome Ann
2022-01-24 15:30:38 -05:00
parent e69265cb4f
commit 3f18888eec
7 changed files with 288 additions and 60 deletions

View File

@ -1,3 +1,32 @@
'''
This file is AGPL-licensed.
Some of the code in this file is from Clover Edition:
https://github.com/cloveranon/Clover-Edition/blob/master/aidungeon/gpt2generator.py
The license for Clover Edition is shown below:
Copyright (c) 2019 Nick Walton
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import multiprocessing
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
import progressbar
@ -33,6 +62,8 @@ def settings_callback() -> dict:
"top_k": 0,
"tfs": 1.0,
"repetition_penalty": 1.0,
"rpslope": 0.0,
"rprange": 0,
}
def started_compiling_callback() -> None:
@ -78,26 +109,40 @@ def __batch_xmap(shard_dim=1):
return inner
def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty):
def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
'''
This gets called inside sample_func to apply repetition penalty
to the 1D array logits using the provided 1D array of tokens to penalize
'''
tokens = np.minimum(tokens, params["n_vocab"]-1) # https://github.com/google/jax/issues/3774
rpslope = np.int32(rpslope)
rprange = np.int32(rprange)
clipped_rprange = rprange if rprange > 0 else tokens.shape[-1]
penalty_arange = np.roll(np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
# Make a new array with the same length as the tokens array but with
# each element replaced by the value at the corresponding index in the
# logits array; e.g.
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
penalty_logits = np.take(logits, tokens)
# Repetition penalty slope
if rpslope != 0.0 and rprange > 0:
_penalty = (penalty_arange/(rprange - 1)) * 2 - 1
_penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
repetition_penalty = _penalty
# Divide positive values by repetition_penalty and multiply negative
# values by repetition_penalty (the academic publication that described
# this technique only divided, but that would cause tokens with negative
# logits to become more likely, which is obviously wrong)
penalty_logits = np.where(
penalty_logits > 0,
penalty_logits/repetition_penalty,
penalty_logits*repetition_penalty,
penalty_arange >= 0,
np.where(
penalty_logits > 0,
penalty_logits/repetition_penalty,
penalty_logits*repetition_penalty,
),
penalty_logits,
)
# Finally, put those penalized logit values back into their original
# positions in the logits array
@ -202,25 +247,46 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
# probability distribution)
return jax.random.categorical(key, logits, -1).astype(np.uint32)
def apply_repetition_penalty_static(logits, tokens, repetition_penalty):
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
'''
This gets called by generate_loop_fn to apply repetition penalty
to the 1D array logits using the provided 1D array of tokens to penalize
'''
rpslope = jnp.int32(rpslope)
rprange = jnp.int32(rprange)
clipped_rprange = jax.lax.cond(rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange)
penalty_arange = jnp.roll(jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
# Make a new array with the same length as the tokens array but with
# each element replaced by the value at the corresponding index in the
# logits array; e.g.
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
penalty_logits = jnp.take(logits, tokens)
# Repetition penalty slope
def apply_slope(carry):
repetition_penalty, rprange = carry
_penalty = (penalty_arange/(rprange - 1)) * 2 - 1
_penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
return _penalty
repetition_penalty = jax.lax.cond(
(rpslope != 0.0) & (rprange > 0), # Not a typo; do not use `and` here, it makes JAX crash
apply_slope,
lambda carry: jnp.full(tokens.shape, carry[0]),
(repetition_penalty, rprange),
)
# Divide positive values by repetition_penalty and multiply negative
# values by repetition_penalty (the academic publication that described
# this technique only divided, but that would cause tokens with negative
# logits to become more likely, which is obviously wrong)
penalty_logits = jnp.where(
penalty_logits > 0,
penalty_logits/repetition_penalty,
penalty_logits*repetition_penalty,
penalty_arange >= 0,
jnp.where(
penalty_logits > 0,
penalty_logits/repetition_penalty,
penalty_logits*repetition_penalty,
),
penalty_logits,
)
# Finally, put those penalized logit values back into their original
# positions in the logits array
@ -325,7 +391,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
pad_token_id = 50256
def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_index, gen_length, rpslope, rprange, sampler_options):
numseqs = numseqs_aux.shape[0]
gi = data[0][1]
def sample_loop_fn(carry):
@ -339,7 +405,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_op
logits = apply_repetition_penalty_dynamic(
logits,
generated,
repetition_penalty
repetition_penalty,
generated_index,
gen_length,
rpslope,
rprange,
)
# Remove any tokens in the badwords list by setting
# their logits to negative infinity which effectively
@ -401,6 +471,8 @@ class PenalizingCausalTransformer(CausalTransformer):
initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
# Get repetition penalty from the arguments
repetition_penalty = sampler_options.pop('repetition_penalty', None)
rpslope = sampler_options.pop('rpslope', None)
rprange = sampler_options.pop('rprange', None)
# This is the main generation loop
def generate_loop_fn(carry):
# Unpack current generate_loop_fn state
@ -427,7 +499,11 @@ class PenalizingCausalTransformer(CausalTransformer):
logits = apply_repetition_penalty_static(
logits,
generated,
repetition_penalty
repetition_penalty,
generated_index,
gen_length,
rpslope,
rprange,
)
# Remove any tokens in the badwords list by setting
# their logits to negative infinity which effectively
@ -586,7 +662,9 @@ class PenalizingCausalTransformer(CausalTransformer):
sample_data[i][2] = logits[i]
sampler_options = settings_callback()
repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
rpslope = sampler_options.pop("rpslope", 0.0)
rprange = sampler_options.pop("rprange", 0)
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, params["seq"] + n_generated, gen_length, rpslope, rprange, sampler_options)
n_generated += 1
for i in range(numseqs):
generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
@ -659,6 +737,8 @@ def infer_static(
top_k=0,
tfs=1.0,
repetition_penalty=1.0,
rpslope=0.0,
rprange=0,
numseqs=1,
gen_len=80,
soft_embeddings: Optional[np.array] = None,
@ -679,6 +759,8 @@ def infer_static(
"top_p": top_p * np.ones(total_batch),
"tfs": tfs * np.ones(total_batch),
"repetition_penalty": repetition_penalty * np.ones(total_batch),
"rpslope": rpslope * np.ones(total_batch),
"rprange": np.full(total_batch, rprange, dtype=np.uint32),
"top_k": np.full(total_batch, top_k, dtype=np.uint32)
}
output = network.generate_static(