Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Commit: Model: Samplers pt. 1
warpers.py (499 lines changed)
@@ -1,4 +1,4 @@
-'''
+"""
 This file is AGPL-licensed.
 
 Some of the code in this file is from Clover Edition:
@@ -25,54 +25,146 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
-'''
+
+---
+
+Some of the code in this file is also from Hugging Face transformers:
+https://github.com/huggingface/transformers
+
+Transformers is licensed under the Apache-2.0 License. The changes made to this
+file are mostly porting warper code to the torch methods.
+"""
+# Comments mostly taken from tpu_mtj_backend.py
 
 from __future__ import annotations
 
 import torch
-from transformers import LogitsWarper
+import jax
+import jax.numpy as jnp
 import numpy as np
+import tpu_mtj_backend
 
 
-class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper):
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        self.penalty_range = int(self.penalty_range)
-        clipped_penalty_range = min(input_ids.shape[-1], self.penalty_range)
-
-        if self.penalty != 1.0:
-            if self.penalty_range > 0:
-                if clipped_penalty_range < input_ids.shape[1]:
-                    input_ids = input_ids[..., -clipped_penalty_range:]
-
-                if self.penalty_slope != 0:
-                    _penalty = (torch.arange(self.penalty_range, dtype=scores.dtype, device=scores.device)/(self.penalty_range - 1)) * 2. - 1
-                    _penalty = (self.penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (self.penalty_slope - 1))
-                    _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (self.penalty - 1)
-                    self.penalty = _penalty[..., -clipped_penalty_range:]
-
-            score = torch.gather(scores, 1, input_ids)
-            if self.use_alt_rep_pen:
-                score = score - torch.log(self.penalty)
-            else:
-                score = torch.where(score <= 0, score * self.penalty, score / self.penalty)
-            scores.scatter_(1, input_ids, score)
+class Warper:
+    @staticmethod
+    def from_id(warper_id: int) -> Warper:
+        return {
+            0: TopK,
+            1: TopA,
+            2: TopP,
+            3: TailFree,
+            4: Typical,
+            5: Temperature,
+            6: RepetitionPenalty,
+        }[warper_id]
+
+
+class Temperature(Warper):
+    """Temperature (just divide the logits by the temperature)"""
+
+    temperature: float = 0.5
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
+        return scores / cls.temperature
+
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        return scores / cls.temperature
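
For intuition, a minimal self-contained NumPy sketch of what temperature does (the names here are illustrative, not part of warpers.py): dividing logits by t < 1 sharpens the softmax distribution toward the argmax, while t > 1 flattens it toward uniform.

import numpy as np

def softmax(x):
    # Numerically stable softmax
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, -1.0])
for t in (0.5, 1.0, 2.0):
    # Same rescaling as Temperature.torch/Temperature.jax above
    print(t, softmax(logits / t).round(3))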
+
+
+class TopP(Warper):
+    """
+    Top-p (after sorting the remaining tokens again in descending order of
+    logit, remove the ones that have cumulative softmax probability
+    greater than p)
+    """
+
+    top_p: float = 0.9
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
+        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs <= (1 - cls.top_p)
+
+        # Scatter the sorted mask back to the original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        return scores.masked_fill(indices_to_remove, -np.inf)
+
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        # Sort the logits array in descending order, replace every element
+        # with e (Euler's number) to the power of that element, and divide
+        # each element of the new array by the sum of the elements in the
+        # new array
+        sorted_logits = -jnp.sort(-scores)
+        probabilities = jax.nn.softmax(sorted_logits)
+        # Calculate cumulative_probabilities as the prefix-sum array of
+        # probabilities
+        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
+        # We want to remove tokens with cumulative probability higher
+        # than top_p
+        sorted_indices_to_remove = cumulative_probabilities > cls.top_p
+        # Don't ever remove the token with the highest logit, even if
+        # the probability is higher than top_p
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        # Unsort and remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
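
The two methods above implement the same nucleus filter on different backends. A standalone NumPy sketch of the idea (illustrative names, not the module's API): sort by logit, keep the smallest prefix whose cumulative softmax probability reaches top_p, always keep the best token, and mask the rest to -inf.

import numpy as np

def top_p_filter(logits, top_p=0.9):
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(-logits)               # descending by logit
    cumulative = np.cumsum(probs[order])
    remove_sorted = cumulative > top_p        # past the nucleus
    remove_sorted[0] = False                  # never drop the best token
    remove = np.empty_like(remove_sorted)
    remove[order] = remove_sorted             # unsort the mask
    return np.where(remove, -np.inf, logits)

print(top_p_filter(np.array([2.0, 1.0, 0.0, -1.0]), top_p=0.8))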
+
+
+class TopK(Warper):
+    """
+    Top-k (keep only the k tokens with the highest logits and remove the rest,
+    by setting their logits to negative infinity)
+    """
+
+    top_k: int = 0
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
+        top_k = min(max(cls.top_k, 1), scores.size(-1))  # Safety check
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
+        scores = scores.masked_fill(indices_to_remove, -np.inf)
+        return scores
+
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        # After sorting the logits array in descending order,
+        # sorted_indices_to_remove is a 1D array that is True for tokens
+        # in the sorted logits array we want to remove and False for ones
+        # we want to keep; in this case the first top_k elements will be
+        # False and the rest will be True
+        sorted_indices_to_remove = np.arange(len(scores)) >= cls.top_k
+        # Unsort the logits array back to its original configuration and
+        # remove tokens we need to remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            np.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return np.where(indices_to_remove, -np.inf, scores)
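
Equivalently, a tiny NumPy sketch of top-k filtering (illustrative, not part of the module): find the k-th largest logit and mask everything below it.

import numpy as np

def top_k_filter(logits, top_k):
    top_k = max(1, min(top_k, logits.shape[-1]))  # same safety clamp as above
    kth_best = np.sort(logits)[-top_k]            # k-th largest logit
    return np.where(logits < kth_best, -np.inf, logits)

print(top_k_filter(np.array([2.0, 1.0, 0.0, -1.0]), top_k=2))
# -> [ 2.  1. -inf -inf]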
 
 
-class TailFreeLogitsWarper(LogitsWarper):
-
-    def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
-        tfs = float(tfs)
-        if tfs < 0 or tfs > 1.0:
-            raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")
-        self.tfs = tfs
-        self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if self.filter_value >= 1.0:
-            return scores
+class TailFree(Warper):
+    """
+    Tail free sampling (basically top-p a second time on remaining tokens, except
+    it's the "cumulative normalized absolute second finite differences of the
+    softmax probabilities" instead of just the cumulative softmax probabilities)
+    """
+
+    tfs: float = 1.0
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
         sorted_logits, sorted_indices = torch.sort(scores, descending=True)
         probs = sorted_logits.softmax(dim=-1)
@@ -82,7 +174,7 @@ class TailFreeLogitsWarper(LogitsWarper):
         normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
 
         # Remove tokens with CDF value above the threshold (tokens with 0 are kept)
-        sorted_indices_to_remove = normalized_d2_cdf > self.tfs
+        sorted_indices_to_remove = normalized_d2_cdf > cls.tfs
 
         # Centre the distribution around the cutoff as in the original implementation of the algorithm
         sorted_indices_to_remove = torch.cat(
@@ -94,32 +186,64 @@ class TailFreeLogitsWarper(LogitsWarper):
             dim=-1,
         )
 
-        if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep
-            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        scores = scores.masked_fill(indices_to_remove, -np.inf)
         return scores
 
-
-class TypicalLogitsWarper(LogitsWarper):
-    '''
-    Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf
-    '''
-
-    def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
-        typical = float(typical)
-        if typical < 0 or typical > 1.0:
-            raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")
-        self.typical = typical
-        self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if self.filter_value >= 1.0:
-            return scores
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        # Sort in descending order
+        sorted_logits = -np.sort(-scores)
+
+        # Softmax again
+        probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
+
+        # Calculate the second finite differences of that array (i.e.
+        # calculate the difference array and then calculate the difference
+        # array of the difference array)
+        d2 = np.diff(np.diff(probabilities))
+
+        # Get the absolute values of all those second finite differences
+        d2 = np.abs(d2)
+
+        # Normalize (all elements in the array are divided by the sum of the
+        # array's elements)
+        d2 = d2 / d2.sum(axis=-1, keepdims=True)
+
+        # Get the prefix-sum array
+        cumulative_d2 = np.cumsum(d2, axis=-1)
+
+        # We will remove the tokens with a cumulative normalized absolute
+        # second finite difference larger than the TFS value
+        sorted_indices_to_remove = cumulative_d2 > cls.tfs
+
+        # Don't remove the token with the highest logit
+        sorted_indices_to_remove[0] = False
+
+        # Since the d2 array has two fewer elements than the logits array,
+        # we'll add two extra Trues to the end
+        sorted_indices_to_remove = np.pad(
+            sorted_indices_to_remove,
+            (0, 2),
+            constant_values=True,
+        )
+        # Unsort and remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            np.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return np.where(indices_to_remove, -np.inf, scores)
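
A standalone NumPy sketch of the tail-free cutoff (illustrative names, not the module's API): the second finite difference of the sorted probabilities measures where the distribution's tail flattens out, and its cumulative normalized absolute values give the cutoff position.

import numpy as np

def tail_free_filter(logits, tfs=0.95):
    order = np.argsort(-logits)
    probs = np.exp(logits[order] - logits.max())
    probs /= probs.sum()
    d2 = np.abs(np.diff(np.diff(probs)))      # second finite differences
    d2 /= d2.sum()                            # normalize to sum to 1
    remove_sorted = np.cumsum(d2) > tfs       # cut past the threshold
    remove_sorted[0] = False                  # keep the best token
    # d2 is two entries shorter than probs, so pad the tail with True
    remove_sorted = np.pad(remove_sorted, (0, 2), constant_values=True)
    remove = np.empty_like(remove_sorted)
    remove[order] = remove_sorted             # unsort the mask
    return np.where(remove, -np.inf, logits)

print(tail_free_filter(np.array([5.0, 4.0, 3.0, 1.0, 0.5, 0.0])))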
+
+
+class Typical(Warper):
+    """Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf"""
+
+    typical: float = 1.0
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
         # Compute softmax probabilities and the natural logarithms of them
         probs = scores.softmax(dim=-1)
         log_probs = probs.log()
@@ -141,42 +265,261 @@ class TypicalLogitsWarper(LogitsWarper):
         # threshold)
         _, sorted_indices = torch.sort(entropy_deviation)
         sorted_logits = probs.gather(-1, sorted_indices)
-        sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= self.typical
+        sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= cls.typical
         sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
 
-        min_tokens_to_keep = max(self.min_tokens_to_keep, 1)
+        min_tokens_to_keep = 1
         # Keep at least min_tokens_to_keep
         sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
 
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        scores = scores.masked_fill(indices_to_remove, -np.inf)
         return scores
 
-
-class TopALogitsWarper(LogitsWarper):
-    def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
-        top_a = float(top_a)
-        if top_a < 0 or top_a > 1.0:
-            raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")
-        self.top_a = top_a
-        self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if self.filter_value >= 1.0:
-            return scores
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        # Compute softmax probabilities and the natural logarithms of them
+        probs = jax.nn.softmax(scores)
+        with np.errstate(divide="ignore"):
+            log_probs = np.log(probs)
+
+        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
+        # in the set of softmax probabilities of the logits
+        neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)
+
+        # Determine the absolute difference between the negative entropy and the
+        # log probabilities
+        entropy_deviation = np.abs(neg_entropy - log_probs)
+
+        # Keep certain tokens such that the sum of the entropy_deviation of the
+        # kept tokens is the smallest possible value such that the sum of the
+        # softmax probabilities of the kept tokens is at least the threshold
+        # value (by sorting the tokens in ascending order of entropy_deviation
+        # and then keeping the smallest possible number of tokens from the
+        # beginning such that the sum of softmax probabilities is at or above the
+        # threshold)
+        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
+        sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= cls.typical
+        sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
+        sorted_indices_to_remove[0] = False
+
+        # Unsort and remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(entropy_deviation),
+            sorted_indices_to_remove,
+        )
+        return np.where(indices_to_remove, -jnp.inf, scores)
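
A NumPy sketch of typical sampling (illustrative, not the module's API): tokens whose information content is closest to the distribution's entropy are kept, in order, until their probability mass reaches the threshold.

import numpy as np

def typical_filter(logits, typical=0.9):
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    with np.errstate(divide="ignore"):
        log_probs = np.log(probs)
    neg_entropy = np.nansum(probs * log_probs)     # -H(p)
    deviation = np.abs(neg_entropy - log_probs)    # distance from entropy
    order = np.argsort(deviation)                  # most "typical" first
    remove_sorted = np.cumsum(probs[order]) >= typical
    remove_sorted = np.roll(remove_sorted, 1)      # keep the crossing token
    remove_sorted[0] = False
    remove = np.empty_like(remove_sorted)
    remove[order] = remove_sorted                  # unsort the mask
    return np.where(remove, -np.inf, logits)

print(typical_filter(np.array([2.0, 1.0, 0.0, -1.0]), typical=0.8))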
+
+
+class TopA(Warper):
+    """
+    Top-a (remove all tokens that have softmax probability less than top_a*m^2
+    where m is the maximum softmax probability)
+    """
+
+    top_a: float = 0.0
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
         sorted_logits, sorted_indices = torch.sort(scores, descending=True)
         probs = sorted_logits.softmax(dim=-1)
 
         # Remove tokens with probability less than top_a*(max(probs))^2 (tokens with 0 are kept)
         probs_max = probs[..., 0, None]
-        sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a
+        sorted_indices_to_remove = probs < probs_max * probs_max * cls.top_a
 
-        if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep
-            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        scores = scores.masked_fill(indices_to_remove, -np.inf)
         return scores
 
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        # Replace every element in the logits array
+        # with e (Euler's number) to the power of that element, and divide
+        # each element of the new array by the sum of the elements in the
+        # new array
+        probabilities = np.array(jax.nn.softmax(scores), copy=True)
+        # Find the largest probability
+        probs_max = probabilities.max()
+        # Remove tokens
+        return np.where(
+            probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
+        )
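
Top-a reduces to a single comparison. A NumPy sketch (illustrative, not the module's API):

import numpy as np

def top_a_filter(logits, top_a=0.2):
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    limit = top_a * probs.max() ** 2   # threshold scales with the square of the max probability
    return np.where(probs < limit, -np.inf, logits)

print(top_a_filter(np.array([2.0, 1.0, 0.0, -1.0]), top_a=0.5))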
+
+
+class RepetitionPenalty(Warper):
+    rep_pen: float = 1.0
+    rep_pen_slope: float = 0.0
+    rep_pen_range: int = 0
+    use_alt_rep_pen: bool = False
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
+        cls.rep_pen_range = int(cls.rep_pen_range)
+        clipped_penalty_range = min(input_ids.shape[-1], cls.rep_pen_range)
+
+        if cls.rep_pen != 1.0:
+            if cls.rep_pen_range > 0:
+                if clipped_penalty_range < input_ids.shape[1]:
+                    input_ids = input_ids[..., -clipped_penalty_range:]
+
+                if cls.rep_pen_slope != 0:
+                    _penalty = (
+                        torch.arange(
+                            cls.rep_pen_range, dtype=scores.dtype, device=scores.device
+                        )
+                        / (cls.rep_pen_range - 1)
+                    ) * 2.0 - 1
+                    _penalty = (cls.rep_pen_slope * _penalty) / (
+                        1 + torch.abs(_penalty) * (cls.rep_pen_slope - 1)
+                    )
+                    _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (cls.rep_pen - 1)
+                    cls.rep_pen = _penalty[..., -clipped_penalty_range:]
+
+            score = torch.gather(scores, 1, input_ids)
+            if cls.use_alt_rep_pen:
+                score = score - torch.log(cls.rep_pen)
+            else:
+                score = torch.where(
+                    score <= 0, score * cls.rep_pen, score / cls.rep_pen
+                )
+            scores.scatter_(1, input_ids, score)
+
+        return scores
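
Ignoring range and slope for a moment, the core of the penalty is an asymmetric scaling; a NumPy sketch (illustrative names, not the module's API):

import numpy as np

def apply_rep_pen(logits, seen_tokens, rep_pen=1.2):
    out = logits.copy()
    s = out[seen_tokens]
    # Divide positive logits, multiply negative ones, so a penalized
    # token always becomes less likely (plain division would make
    # negative logits *more* likely)
    out[seen_tokens] = np.where(s > 0, s / rep_pen, s * rep_pen)
    return out

logits = np.array([2.0, -1.0, 0.5, 3.0])
print(apply_rep_pen(logits, np.array([0, 1]), rep_pen=2.0))
# -> [ 1. -2.  0.5  3.]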
+
+    @classmethod
+    def jax_static(
+        cls,
+        logits: jnp.array,
+        tokens: jnp.array,
+        generated_index,
+    ) -> jnp.array:
+        """
+        This gets called by generate_loop_fn to apply repetition penalty
+        to the 1D array logits using the provided 1D array of tokens to penalize
+        """
+        rpslope = jnp.int32(cls.rep_pen_slope)
+        rprange = jnp.int32(cls.rep_pen_range)
+        repetition_penalty = cls.rep_pen
+
+        clipped_rprange = jax.lax.cond(
+            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
+        )
+
+        penalty_arange = jnp.roll(
+            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
+            generated_index,
+            axis=-1,
+        )
+        # Make a new array with the same length as the tokens array but with
+        # each element replaced by the value at the corresponding index in the
+        # logits array; e.g.
+        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
+        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
+        penalty_logits = jnp.take(logits, tokens)
+
+        # Repetition penalty slope
+        def apply_slope(carry):
+            repetition_penalty, rprange = carry
+            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
+            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
+            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
+            return _penalty
+
+        repetition_penalty = jax.lax.cond(
+            (rpslope != 0.0)
+            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
+            apply_slope,
+            lambda carry: jnp.full(tokens.shape, carry[0]),
+            (repetition_penalty, rprange),
+        )
+        # Divide positive values by repetition_penalty and multiply negative
+        # values by repetition_penalty (the academic publication that described
+        # this technique originally only divided, but that would cause tokens
+        # with negative logits to become more likely, which is obviously wrong)
+        if cls.use_alt_rep_pen:
+            penalty_logits = jnp.where(
+                penalty_arange >= 0,
+                penalty_logits - jnp.log(repetition_penalty),
+                penalty_logits,
+            )
+        else:
+            penalty_logits = jnp.where(
+                penalty_arange >= 0,
+                jnp.where(
+                    penalty_logits > 0,
+                    penalty_logits / repetition_penalty,
+                    penalty_logits * repetition_penalty,
+                ),
+                penalty_logits,
+            )
+        # Finally, put those penalized logit values back into their original
+        # positions in the logits array
+        return logits.at[tokens].set(penalty_logits)
+
+    @classmethod
+    def jax_dynamic(
+        cls,
+        scores: jnp.array,
+        tokens: jnp.array,
+        generated_index,
+    ) -> jnp.array:
+        """
+        This gets called by generate_loop_fn to apply repetition penalty
+        to the 1D array logits using the provided 1D array of tokens to penalize
+        """
+        tokens = np.minimum(
+            tokens, tpu_mtj_backend.params["n_vocab"] - 1
+        )  # https://github.com/google/jax/issues/3774
+
+        rpslope = np.int32(cls.rep_pen_slope)
+        rprange = np.int32(cls.rep_pen_range)
+        repetition_penalty = cls.rep_pen
+
+        clipped_rprange = rprange if rprange > 0 else tokens.shape[-1]
+        penalty_arange = np.roll(
+            np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
+            generated_index,
+            axis=-1,
+        )
+        # Make a new array with the same length as the tokens array but with
+        # each element replaced by the value at the corresponding index in the
+        # logits array; e.g.
+        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
+        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
+        penalty_logits = np.take(scores, tokens)
+
+        # Repetition penalty slope
+        if rpslope != 0.0 and rprange > 0:
+            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
+            _penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
+            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
+            repetition_penalty = _penalty
+
+        # Divide positive values by repetition_penalty and multiply negative
+        # values by repetition_penalty (the academic publication that described
+        # this technique originally only divided, but that would cause tokens
+        # with negative logits to become more likely, which is obviously wrong)
+        if cls.use_alt_rep_pen:
+            penalty_logits = np.where(
+                penalty_arange >= 0,
+                penalty_logits - np.log(repetition_penalty),
+                penalty_logits,
+            )
+        else:
+            penalty_logits = np.where(
+                penalty_arange >= 0,
+                np.where(
+                    penalty_logits > 0,
+                    penalty_logits / repetition_penalty,
+                    penalty_logits * repetition_penalty,
+                ),
+                penalty_logits,
+            )
+        # Finally, put those penalized logit values back into their original
+        # positions in the logits array
+        scores[tokens] = penalty_logits
+        return scores
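
The slope logic shared by all three backends can be isolated into a few lines; a NumPy sketch (illustrative, not the module's API) of how the effective penalty ramps across the penalty window when rep_pen_slope is nonzero:

import numpy as np

def sloped_penalty(rep_pen, rep_pen_slope, rep_pen_range):
    # Position in the window mapped to [-1, 1], oldest to newest
    x = (np.arange(rep_pen_range) / (rep_pen_range - 1)) * 2 - 1
    # Sigmoid-like remap; the slope controls how sharp the transition is
    x = (rep_pen_slope * x) / (1 + np.abs(x) * (rep_pen_slope - 1))
    # Ramp from ~1 for the oldest tokens up to rep_pen for the newest
    return 1 + ((x + 1) / 2) * (rep_pen - 1)

print(sloped_penalty(rep_pen=2.0, rep_pen_slope=3.0, rep_pen_range=5).round(3))
# -> [1.    1.125 1.5   1.875 2.   ]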