Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Repetition penalty slope and range
@@ -1,3 +1,32 @@
+'''
+This file is AGPL-licensed.
+
+Some of the code in this file is from Clover Edition:
+https://github.com/cloveranon/Clover-Edition/blob/master/aidungeon/gpt2generator.py
+
+The license for Clover Edition is shown below:
+
+Copyright (c) 2019 Nick Walton
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+'''
+
 import multiprocessing
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 import progressbar
@@ -33,6 +62,8 @@ def settings_callback() -> dict:
         "top_k": 0,
         "tfs": 1.0,
         "repetition_penalty": 1.0,
+        "rpslope": 0.0,
+        "rprange": 0,
     }

 def started_compiling_callback() -> None:
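Reading the implementation in the hunks below, the intended semantics of the two new settings appear to be: `rprange` limits the repetition penalty to the most recent `rprange` tokens, with 0 meaning the penalty applies to every token in the buffer, and `rpslope` controls how sharply the penalty ramps from no effect at the oldest in-range token up to the full `repetition_penalty` at the newest one; 0.0 keeps the penalty flat across the whole range.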
@@ -78,26 +109,40 @@ def __batch_xmap(shard_dim=1):
     return inner


-def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty):
+def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
     to the 1D array logits using the provided 1D array of tokens to penalize
     '''
     tokens = np.minimum(tokens, params["n_vocab"]-1) # https://github.com/google/jax/issues/3774
+    rpslope = np.int32(rpslope)
+    rprange = np.int32(rprange)
+    clipped_rprange = rprange if rprange > 0 else tokens.shape[-1]
+    penalty_arange = np.roll(np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
     # Make a new array with the same length as the tokens array but with
     # each element replaced by the value at the corresponding index in the
     # logits array; e.g.
     # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
     # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
     penalty_logits = np.take(logits, tokens)
+    # Repetition penalty slope
+    if rpslope != 0.0 and rprange > 0:
+        _penalty = (penalty_arange/(rprange - 1)) * 2 - 1
+        _penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
+        _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
+        repetition_penalty = _penalty
     # Divide positive values by repetition_penalty and multiply negative
     # values by repetition_penalty (the academic publication that described
     # this technique actually just only divided, but that would cause tokens
     # with negative logits to become more likely, which is obviously wrong)
     penalty_logits = np.where(
-        penalty_logits > 0,
-        penalty_logits/repetition_penalty,
-        penalty_logits*repetition_penalty,
+        penalty_arange >= 0,
+        np.where(
+            penalty_logits > 0,
+            penalty_logits/repetition_penalty,
+            penalty_logits*repetition_penalty,
+        ),
+        penalty_logits,
     )
     # Finally, put those penalized logit values back into their original
     # positions in the logits array
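The slope math above gives each in-range position its own penalty strength. A minimal standalone sketch (illustrative values; not part of the commit) of the curve it produces:

    import numpy as np

    repetition_penalty = 1.2
    rpslope = 3.4   # illustrative slope; larger values sharpen the ramp
    rprange = 8     # penalize only the last 8 tokens

    pos = np.arange(rprange)                   # 0 = oldest in-range position
    x = (pos / (rprange - 1)) * 2 - 1          # rescale positions to [-1, 1]
    s = (rpslope * x) / (1 + np.abs(x) * (rpslope - 1))  # sigmoid-like curve in [-1, 1]
    per_token_penalty = 1 + ((s + 1) / 2) * (repetition_penalty - 1)
    print(per_token_penalty.round(3))
    # rises from 1.0 (no penalty) at the oldest position to 1.2 at the newest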
@@ -202,25 +247,46 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
     # probability distribution)
     return jax.random.categorical(key, logits, -1).astype(np.uint32)

-def apply_repetition_penalty_static(logits, tokens, repetition_penalty):
+def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
     to the 1D array logits using the provided 1D array of tokens to penalize
     '''
+    rpslope = jnp.int32(rpslope)
+    rprange = jnp.int32(rprange)
+    clipped_rprange = jax.lax.cond(rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange)
+    penalty_arange = jnp.roll(jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
     # Make a new array with the same length as the tokens array but with
     # each element replaced by the value at the corresponding index in the
     # logits array; e.g.
     # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
     # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
     penalty_logits = jnp.take(logits, tokens)
+    # Repetition penalty slope
+    def apply_slope(carry):
+        repetition_penalty, rprange = carry
+        _penalty = (penalty_arange/(rprange - 1)) * 2 - 1
+        _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
+        _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
+        return _penalty
+    repetition_penalty = jax.lax.cond(
+        (rpslope != 0.0) & (rprange > 0), # Not a typo; do not use `and` here, it makes JAX crash
+        apply_slope,
+        lambda carry: jnp.full(tokens.shape, carry[0]),
+        (repetition_penalty, rprange),
+    )
     # Divide positive values by repetition_penalty and multiply negative
     # values by repetition_penalty (the academic publication that described
     # this technique actually just only divided, but that would cause tokens
     # with negative logits to become more likely, which is obviously wrong)
     penalty_logits = jnp.where(
-        penalty_logits > 0,
-        penalty_logits/repetition_penalty,
-        penalty_logits*repetition_penalty,
+        penalty_arange >= 0,
+        jnp.where(
+            penalty_logits > 0,
+            penalty_logits/repetition_penalty,
+            penalty_logits*repetition_penalty,
+        ),
+        penalty_logits,
     )
     # Finally, put those penalized logit values back into their original
     # positions in the logits array
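The inline comment about `&` deserves a demonstration. A small sketch (an assumed example, separate from the commit) of why jitted JAX code needs the elementwise `&` rather than Python's short-circuiting `and`:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def pick(a, b):
        # `a != 0.0 and b > 0` would fail here: under tracing, `and` tries to
        # coerce an abstract tracer to a concrete Python bool and raises.
        cond = (a != 0.0) & (b > 0)
        return jax.lax.cond(cond, lambda _: a, lambda _: -a, None)

    print(pick(jnp.float32(2.0), jnp.int32(5)))  # 2.0
    print(pick(jnp.float32(2.0), jnp.int32(0)))  # -2.0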
@@ -325,7 +391,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):

 pad_token_id = 50256

-def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
+def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_index, gen_length, rpslope, rprange, sampler_options):
     numseqs = numseqs_aux.shape[0]
     gi = data[0][1]
     def sample_loop_fn(carry):
@@ -339,7 +405,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
         logits = apply_repetition_penalty_dynamic(
             logits,
             generated,
-            repetition_penalty
+            repetition_penalty,
+            generated_index,
+            gen_length,
+            rpslope,
+            rprange,
         )
         # Remove any tokens in the badwords list by setting
         # their logits to negative infinity which effectively
@@ -401,6 +471,8 @@ class PenalizingCausalTransformer(CausalTransformer):
         initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
         # Get repetition penalty from the arguments
         repetition_penalty = sampler_options.pop('repetition_penalty', None)
+        rpslope = sampler_options.pop('rpslope', None)
+        rprange = sampler_options.pop('rprange', None)
         # This is the main generation loop
         def generate_loop_fn(carry):
             # Unpack current generate_loop_fn state
@@ -427,7 +499,11 @@ class PenalizingCausalTransformer(CausalTransformer):
             logits = apply_repetition_penalty_static(
                 logits,
                 generated,
-                repetition_penalty
+                repetition_penalty,
+                generated_index,
+                gen_length,
+                rpslope,
+                rprange,
             )
             # Remove any tokens in the badwords list by setting
             # their logits to negative infinity which effectively
@@ -586,7 +662,9 @@ class PenalizingCausalTransformer(CausalTransformer):
             sample_data[i][2] = logits[i]
         sampler_options = settings_callback()
         repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
-        sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
+        rpslope = sampler_options.pop("rpslope", 0.0)
+        rprange = sampler_options.pop("rprange", 0)
+        sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, params["seq"] + n_generated, gen_length, rpslope, rprange, sampler_options)
         n_generated += 1
         for i in range(numseqs):
             generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
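The `generated_index` argument threaded through this call site (`params["seq"] + n_generated`) is the rotation offset that `penalty_arange` uses in the two penalty functions. A standalone sketch with made-up sizes (not from the commit) shows how the roll-and-mask construction picks out the window of the `rprange` most recent tokens:

    import numpy as np

    seq_len = 10          # illustrative token-buffer length
    rprange = 4           # penalize only the last 4 tokens
    generated_index = 6   # illustrative write position in the buffer

    # Same construction as penalty_arange in the hunks above
    penalty_arange = np.roll(np.arange(seq_len) + (rprange - seq_len), generated_index, axis=-1)
    print(penalty_arange)       # [-2 -1  0  1  2  3 -6 -5 -4 -3]
    print(penalty_arange >= 0)  # True only for the 4 positions ending at index generated_index - 1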
@@ -659,6 +737,8 @@ def infer_static(
     top_k=0,
     tfs=1.0,
     repetition_penalty=1.0,
+    rpslope=0.0,
+    rprange=0,
     numseqs=1,
     gen_len=80,
     soft_embeddings: Optional[np.array] = None,
@@ -679,6 +759,8 @@ def infer_static(
         "top_p": top_p * np.ones(total_batch),
         "tfs": tfs * np.ones(total_batch),
         "repetition_penalty": repetition_penalty * np.ones(total_batch),
+        "rpslope": rpslope * np.ones(total_batch),
+        "rprange": np.full(total_batch, rprange, dtype=np.uint32),
         "top_k": np.full(total_batch, top_k, dtype=np.uint32)
     }
     output = network.generate_static(
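As with the existing sampler options, the two new scalars are broadcast to one value per batch element before being handed to the sampler. A minimal sketch with assumed values:

    import numpy as np

    total_batch = 4
    rpslope, rprange = 0.7, 1024   # illustrative values

    sampler_options = {
        "rpslope": rpslope * np.ones(total_batch),                  # float array [0.7 0.7 0.7 0.7]
        "rprange": np.full(total_batch, rprange, dtype=np.uint32),  # uint32 array [1024 1024 1024 1024]
    }
    print(sampler_options["rpslope"].shape)   # (4,)
    print(sampler_options["rprange"].dtype)   # uint32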