mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Model: Samplers pt. 1
This commit is contained in:
501
warpers.py
501
warpers.py
@@ -1,4 +1,4 @@
|
|||||||
'''
|
"""
|
||||||
This file is AGPL-licensed.
|
This file is AGPL-licensed.
|
||||||
|
|
||||||
Some of the code in this file is from Clover Edition:
|
Some of the code in this file is from Clover Edition:
|
||||||
@@ -25,54 +25,146 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
SOFTWARE.
|
SOFTWARE.
|
||||||
'''
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Some of the code in this file is also from Hugging Face logitsTransformers:
|
||||||
|
https://github.com/huggingface/transformers
|
||||||
|
|
||||||
|
Transformers is licensed under the Apache-2.0 License. The changes made to this
|
||||||
|
file are mostly porting warper code to the torch methods.
|
||||||
|
"""
|
||||||
|
# Comments mostly taken from tpu_mtj_backend.py
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import LogitsWarper
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import tpu_mtj_backend
|
||||||
|
|
||||||
|
|
||||||
class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper):
|
class Warper:
|
||||||
def __init__(self, *args, **kwargs):
|
@staticmethod
|
||||||
pass
|
def from_id(warper_id: int) -> Warper:
|
||||||
|
return {
|
||||||
|
0: TopK,
|
||||||
|
1: TopA,
|
||||||
|
2: TopP,
|
||||||
|
3: TailFree,
|
||||||
|
4: Typical,
|
||||||
|
5: Temperature,
|
||||||
|
6: RepetitionPenalty,
|
||||||
|
}[warper_id]
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
|
||||||
self.penalty_range = int(self.penalty_range)
|
|
||||||
clipped_penalty_range = min(input_ids.shape[-1], self.penalty_range)
|
|
||||||
|
|
||||||
if self.penalty != 1.0:
|
class Temperature(Warper):
|
||||||
if self.penalty_range > 0:
|
"""Temperature (just divide the logits by the temperature)"""
|
||||||
if clipped_penalty_range < input_ids.shape[1]:
|
|
||||||
input_ids = input_ids[..., -clipped_penalty_range:]
|
|
||||||
|
|
||||||
if self.penalty_slope != 0:
|
temperature: float = 0.5
|
||||||
_penalty = (torch.arange(self.penalty_range, dtype=scores.dtype, device=scores.device)/(self.penalty_range - 1)) * 2. - 1
|
|
||||||
_penalty = (self.penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (self.penalty_slope - 1))
|
|
||||||
_penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (self.penalty - 1)
|
|
||||||
self.penalty = _penalty[..., -clipped_penalty_range:]
|
|
||||||
|
|
||||||
score = torch.gather(scores, 1, input_ids)
|
@classmethod
|
||||||
if self.use_alt_rep_pen:
|
def torch(cls, scores: torch.Tensor) -> torch.Tensor:
|
||||||
score = score - torch.log(self.penalty)
|
return scores / cls.temperature
|
||||||
else:
|
|
||||||
score = torch.where(score <= 0, score * self.penalty, score / self.penalty)
|
|
||||||
scores.scatter_(1, input_ids, score)
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def jax(cls, scores: jnp.array) -> jnp.array:
|
||||||
|
return scores / cls.value
|
||||||
|
|
||||||
|
|
||||||
|
class TopP(Warper):
|
||||||
|
"""
|
||||||
|
Top-p (after sorting the remaining tokens again in descending order of
|
||||||
|
logit, remove the ones that have cumulative softmax probability
|
||||||
|
greater than p)
|
||||||
|
"""
|
||||||
|
|
||||||
|
top_p: float = 0.9
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def torch(cls, scores: torch.Tensor) -> torch.Tensor:
|
||||||
|
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
|
||||||
|
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||||
|
|
||||||
|
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||||
|
sorted_indices_to_remove = cumulative_probs <= (1 - cls.value)
|
||||||
|
|
||||||
|
# scatter sorted tensors to original indexing
|
||||||
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||||
|
1, sorted_indices, sorted_indices_to_remove
|
||||||
|
)
|
||||||
|
return scores.masked_fill(indices_to_remove, -np.inf)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def jax(cls, scores: jnp.array) -> jnp.array:
|
||||||
|
# Sort the logits array in descending order, replace every element
|
||||||
|
# with e (Euler's number) to the power of that element, and divide
|
||||||
|
# each element of the new array by the sum of the elements in the
|
||||||
|
# new array
|
||||||
|
sorted_logits = -jnp.sort(-scores)
|
||||||
|
probabilities = jax.nn.softmax(sorted_logits)
|
||||||
|
# Calculate cumulative_probabilities as the prefix-sum array of
|
||||||
|
# probabilities
|
||||||
|
cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
|
||||||
|
# We want to remove tokens with cumulative probability higher
|
||||||
|
# than top_p
|
||||||
|
sorted_indices_to_remove = cumulative_probabilities > cls.value
|
||||||
|
# Don't ever remove the token with the highest logit, even if
|
||||||
|
# the probability is higher than top_p
|
||||||
|
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
||||||
|
# Unsort and remove
|
||||||
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
|
jnp.argsort(-scores),
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
)
|
||||||
|
return jnp.where(indices_to_remove, -jnp.inf, scores)
|
||||||
|
|
||||||
|
|
||||||
|
class TopK(Warper):
|
||||||
|
"""
|
||||||
|
Top-k (keep only the k tokens with the highest logits and remove the rest,
|
||||||
|
by setting their logits to negative infinity)
|
||||||
|
"""
|
||||||
|
|
||||||
|
top_k: int = 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def torch(cls, scores: torch.Tensor) -> torch.Tensor:
|
||||||
|
top_k = min(max(cls.top_k, 1), scores.size(-1)) # Safety check
|
||||||
|
# Remove all tokens with a probability less than the last token of the top-k
|
||||||
|
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
|
||||||
|
scores = scores.masked_fill(indices_to_remove, -np.inf)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def jax(cls, scores: jnp.array) -> jnp.array:
|
||||||
|
# After sorting the logits array in descending order,
|
||||||
|
# sorted_indices_to_remove is a 1D array that is True for tokens
|
||||||
|
# in the sorted logits array we want to remove and False for ones
|
||||||
|
# we want to keep, in this case the first top_k elements will be
|
||||||
|
# False and the rest will be True
|
||||||
|
sorted_indices_to_remove = np.arange(len(scores)) >= cls.top_k
|
||||||
|
# Unsort the logits array back to its original configuration and
|
||||||
|
# remove tokens we need to remove
|
||||||
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
|
np.argsort(-scores),
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
)
|
||||||
|
return np.where(indices_to_remove, -np.inf, scores)
|
||||||
|
|
||||||
class TailFreeLogitsWarper(LogitsWarper):
|
|
||||||
|
|
||||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
class TailFree(Warper):
|
||||||
tfs = float(tfs)
|
"""
|
||||||
if tfs < 0 or tfs > 1.0:
|
Tail free sampling (basically top-p a second time on remaining tokens except
|
||||||
raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")
|
it's the "cumulative normalized absolute second finite differences of the
|
||||||
self.tfs = tfs
|
softmax probabilities" instead of just the cumulative softmax probabilities)
|
||||||
self.filter_value = filter_value
|
"""
|
||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
tfs: float = 1.0
|
||||||
if self.filter_value >= 1.0:
|
|
||||||
return scores
|
@classmethod
|
||||||
|
def torch(cls, scores: torch.Tensor) -> torch.Tensor:
|
||||||
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
||||||
probs = sorted_logits.softmax(dim=-1)
|
probs = sorted_logits.softmax(dim=-1)
|
||||||
|
|
||||||
@@ -82,7 +174,7 @@ class TailFreeLogitsWarper(LogitsWarper):
|
|||||||
normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
|
normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
|
||||||
|
|
||||||
# Remove tokens with CDF value above the threshold (token with 0 are kept)
|
# Remove tokens with CDF value above the threshold (token with 0 are kept)
|
||||||
sorted_indices_to_remove = normalized_d2_cdf > self.tfs
|
sorted_indices_to_remove = normalized_d2_cdf > cls.tfs
|
||||||
|
|
||||||
# Centre the distribution around the cutoff as in the original implementation of the algorithm
|
# Centre the distribution around the cutoff as in the original implementation of the algorithm
|
||||||
sorted_indices_to_remove = torch.cat(
|
sorted_indices_to_remove = torch.cat(
|
||||||
@@ -94,32 +186,64 @@ class TailFreeLogitsWarper(LogitsWarper):
|
|||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.min_tokens_to_keep > 1:
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||||
# Keep at least min_tokens_to_keep
|
1, sorted_indices, sorted_indices_to_remove
|
||||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
)
|
||||||
|
scores = scores.masked_fill(indices_to_remove, -np.inf)
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
|
||||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def jax(cls, scores: jnp.array) -> jnp.array:
|
||||||
|
# Sort in descending order
|
||||||
|
sorted_logits = -np.sort(-scores)
|
||||||
|
|
||||||
class TypicalLogitsWarper(LogitsWarper):
|
# Softmax again
|
||||||
'''
|
probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
|
||||||
Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
# Calculate the second finite differences of that array (i.e.
|
||||||
typical = float(typical)
|
# calculate the difference array and then calculate the difference
|
||||||
if typical < 0 or typical > 1.0:
|
# array of the difference array)
|
||||||
raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")
|
d2 = np.diff(np.diff(probabilities))
|
||||||
self.typical = typical
|
|
||||||
self.filter_value = filter_value
|
|
||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
# Get the absolute values of all those second finite differences
|
||||||
if self.filter_value >= 1.0:
|
d2 = np.abs(d2)
|
||||||
return scores
|
|
||||||
|
|
||||||
|
# Normalize (all elements in the array are divided by the sum of the
|
||||||
|
# array's elements)
|
||||||
|
d2 = d2 / d2.sum(axis=-1, keepdims=True)
|
||||||
|
|
||||||
|
# Get the prefix-sum array
|
||||||
|
cumulative_d2 = np.cumsum(d2, axis=-1)
|
||||||
|
|
||||||
|
# We will remove the tokens with a cumulative normalized absolute
|
||||||
|
# second finite difference larger than the TFS value
|
||||||
|
sorted_indices_to_remove = cumulative_d2 > cls.tfs
|
||||||
|
|
||||||
|
# Don't remove the token with the highest logit
|
||||||
|
sorted_indices_to_remove[0] = False
|
||||||
|
|
||||||
|
# Since the d2 array has two fewer elements than the logits array,
|
||||||
|
# we'll add two extra Trues to the end
|
||||||
|
sorted_indices_to_remove = np.pad(
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
(0, 2),
|
||||||
|
constant_values=True,
|
||||||
|
)
|
||||||
|
# Unsort and remove
|
||||||
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
|
np.argsort(-scores),
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
)
|
||||||
|
return np.where(indices_to_remove, -np.inf, scores)
|
||||||
|
|
||||||
|
|
||||||
|
class Typical(Warper):
|
||||||
|
"""Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf"""
|
||||||
|
|
||||||
|
typical: float = 1.0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def torch(cls, scores: torch.Tensor) -> torch.Tensor:
|
||||||
# Compute softmax probabilities and the natural logarithms of them
|
# Compute softmax probabilities and the natural logarithms of them
|
||||||
probs = scores.softmax(dim=-1)
|
probs = scores.softmax(dim=-1)
|
||||||
log_probs = probs.log()
|
log_probs = probs.log()
|
||||||
@@ -141,42 +265,261 @@ class TypicalLogitsWarper(LogitsWarper):
|
|||||||
# threshold)
|
# threshold)
|
||||||
_, sorted_indices = torch.sort(entropy_deviation)
|
_, sorted_indices = torch.sort(entropy_deviation)
|
||||||
sorted_logits = probs.gather(-1, sorted_indices)
|
sorted_logits = probs.gather(-1, sorted_indices)
|
||||||
sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= self.typical
|
sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= cls.typical
|
||||||
sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
|
sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
|
||||||
|
|
||||||
min_tokens_to_keep = max(self.min_tokens_to_keep, 1)
|
min_tokens_to_keep = 1
|
||||||
# Keep at least min_tokens_to_keep
|
# Keep at least min_tokens_to_keep
|
||||||
sorted_indices_to_remove[..., : min_tokens_to_keep] = 0
|
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
||||||
|
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
1, sorted_indices, sorted_indices_to_remove
|
||||||
|
)
|
||||||
|
scores = scores.masked_fill(indices_to_remove, -np.inf)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def jax(cls, scores: jnp.array) -> jnp.array:
|
||||||
|
# Compute softmax probabilities and the natural logarithms of them
|
||||||
|
probs = jax.nn.softmax(scores)
|
||||||
|
with np.errstate(divide="ignore"):
|
||||||
|
log_probs = np.log(probs)
|
||||||
|
|
||||||
class TopALogitsWarper(LogitsWarper):
|
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
|
||||||
def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
# in the set of softmax probabilities of the logits
|
||||||
top_a = float(top_a)
|
neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)
|
||||||
if top_a < 0 or top_a > 1.0:
|
|
||||||
raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")
|
|
||||||
self.top_a = top_a
|
|
||||||
self.filter_value = filter_value
|
|
||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
# Determine absolute difference between the negative entropy and the
|
||||||
if self.filter_value >= 1.0:
|
# log probabilities
|
||||||
return scores
|
entropy_deviation = np.abs(neg_entropy - log_probs)
|
||||||
|
|
||||||
|
# Keep certain tokens such that the sum of the entropy_deviation of the
|
||||||
|
# kept tokens is the smallest possible value such that the sum of the
|
||||||
|
# softmax probabilities of the kept tokens is at least the threshold
|
||||||
|
# value (by sorting the tokens in ascending order of entropy_deviation
|
||||||
|
# and then keeping the smallest possible number of tokens from the
|
||||||
|
# beginning such that sum of softmax probabilities is at or above the
|
||||||
|
# threshold)
|
||||||
|
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
|
||||||
|
sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= cls.typical
|
||||||
|
sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
|
||||||
|
sorted_indices_to_remove[0] = False
|
||||||
|
|
||||||
|
# Unsort and remove
|
||||||
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
|
jnp.argsort(entropy_deviation),
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
)
|
||||||
|
return np.where(indices_to_remove, -jnp.inf, scores)
|
||||||
|
|
||||||
|
|
||||||
|
class TopA(Warper):
|
||||||
|
"""
|
||||||
|
Top-a (remove all tokens that have softmax probability less than *m^2 where
|
||||||
|
m is the maximum softmax probability)
|
||||||
|
"""
|
||||||
|
|
||||||
|
top_a: float = 0.0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def torch(cls, scores: torch.Tensor) -> torch.Tensor:
|
||||||
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
||||||
probs = sorted_logits.softmax(dim=-1)
|
probs = sorted_logits.softmax(dim=-1)
|
||||||
|
|
||||||
# Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
|
# Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
|
||||||
probs_max = probs[..., 0, None]
|
probs_max = probs[..., 0, None]
|
||||||
sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a
|
sorted_indices_to_remove = probs < probs_max * probs_max * cls.top_a
|
||||||
|
|
||||||
if self.min_tokens_to_keep > 1:
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||||
# Keep at least min_tokens_to_keep
|
1, sorted_indices, sorted_indices_to_remove
|
||||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
)
|
||||||
|
scores = scores.masked_fill(indices_to_remove, -np.inf)
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
return scores
|
||||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
|
||||||
|
@classmethod
|
||||||
|
def jax(cls, scores: jnp.array) -> jnp.array:
|
||||||
|
# Replace every element in the logits array
|
||||||
|
# with e (Euler's number) to the power of that element, and divide
|
||||||
|
# each element of the new array by the sum of the elements in the
|
||||||
|
# new array
|
||||||
|
probabilities = np.array(jax.nn.softmax(scores), copy=True)
|
||||||
|
# Find the largest probability
|
||||||
|
probs_max = probabilities.max()
|
||||||
|
# Remove tokens
|
||||||
|
return np.where(
|
||||||
|
probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RepetitionPenalty(Warper):
|
||||||
|
rep_pen: float = 1.0
|
||||||
|
rep_pen_slope: float = 0.0
|
||||||
|
rep_pen_range: int = 0
|
||||||
|
use_alt_rep_pen: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def torch(cls, scores: torch.Tensor) -> torch.Tensor:
|
||||||
|
cls.rep_pen_range = int(cls.rep_pen_range)
|
||||||
|
clipped_penalty_range = min(input_ids.shape[-1], cls.rep_pen_range)
|
||||||
|
|
||||||
|
if cls.rep_pen != 1.0:
|
||||||
|
if cls.rep_pen_range > 0:
|
||||||
|
if clipped_penalty_range < input_ids.shape[1]:
|
||||||
|
input_ids = input_ids[..., -clipped_penalty_range:]
|
||||||
|
|
||||||
|
if cls.rep_pen_slope != 0:
|
||||||
|
_penalty = (
|
||||||
|
torch.arange(
|
||||||
|
cls.rep_pen_range, dtype=scores.dtype, device=scores.device
|
||||||
|
)
|
||||||
|
/ (cls.rep_pen_range - 1)
|
||||||
|
) * 2.0 - 1
|
||||||
|
_penalty = (cls.rep_pen_slope * _penalty) / (
|
||||||
|
1 + torch.abs(_penalty) * (cls.rep_pen_slope - 1)
|
||||||
|
)
|
||||||
|
_penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (cls.rep_pen - 1)
|
||||||
|
cls.rep_pen = _penalty[..., -clipped_penalty_range:]
|
||||||
|
|
||||||
|
score = torch.gather(scores, 1, input_ids)
|
||||||
|
if cls.use_alt_rep_pen:
|
||||||
|
score = score - torch.log(cls.rep_pen)
|
||||||
|
else:
|
||||||
|
score = torch.where(
|
||||||
|
score <= 0, score * cls.rep_pen, score / cls.rep_pen
|
||||||
|
)
|
||||||
|
scores.scatter_(1, input_ids, score)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
# def jax_static(cls, scores: jnp.array) -> jnp.array:
|
||||||
|
def jax_static(
|
||||||
|
cls,
|
||||||
|
logits: jnp.array,
|
||||||
|
tokens: jnp.array,
|
||||||
|
generated_index,
|
||||||
|
) -> jnp.array:
|
||||||
|
"""
|
||||||
|
This gets called by generate_loop_fn to apply repetition penalty
|
||||||
|
to the 1D array logits using the provided 1D array of tokens to penalize
|
||||||
|
"""
|
||||||
|
rpslope = jnp.int32(cls.rep_pen_slope)
|
||||||
|
rprange = jnp.int32(cls.rep_pen_range)
|
||||||
|
repetition_penalty = cls.rep_pen
|
||||||
|
|
||||||
|
clipped_rprange = jax.lax.cond(
|
||||||
|
rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
|
||||||
|
)
|
||||||
|
|
||||||
|
penalty_arange = jnp.roll(
|
||||||
|
jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
|
||||||
|
generated_index,
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
# Make a new array with the same length as the tokens array but with
|
||||||
|
# each element replaced by the value at the corresponding index in the
|
||||||
|
# logits array; e.g.
|
||||||
|
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
||||||
|
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
||||||
|
penalty_logits = jnp.take(logits, tokens)
|
||||||
|
# Repetition penalty slope
|
||||||
|
def apply_slope(carry):
|
||||||
|
repetition_penalty, rprange = carry
|
||||||
|
_penalty = (penalty_arange / (rprange - 1)) * 2 - 1
|
||||||
|
_penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
|
||||||
|
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
|
||||||
|
return _penalty
|
||||||
|
|
||||||
|
repetition_penalty = jax.lax.cond(
|
||||||
|
(rpslope != 0.0)
|
||||||
|
& (rprange > 0), # Not a typo; do not use `and` here, it makes JAX crash
|
||||||
|
apply_slope,
|
||||||
|
lambda carry: jnp.full(tokens.shape, carry[0]),
|
||||||
|
(repetition_penalty, rprange),
|
||||||
|
)
|
||||||
|
# Divide positive values by repetition_penalty and multiply negative
|
||||||
|
# values by repetition_penalty (the academic publication that described
|
||||||
|
# this technique actually just only divided, but that would cause tokens
|
||||||
|
# with negative logits to become more likely, which is obviously wrong)
|
||||||
|
if cls.use_alt_rep_pen:
|
||||||
|
penalty_logits = jnp.where(
|
||||||
|
penalty_arange >= 0,
|
||||||
|
penalty_logits - jnp.log(repetition_penalty),
|
||||||
|
penalty_logits,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
penalty_logits = jnp.where(
|
||||||
|
penalty_arange >= 0,
|
||||||
|
jnp.where(
|
||||||
|
penalty_logits > 0,
|
||||||
|
penalty_logits / repetition_penalty,
|
||||||
|
penalty_logits * repetition_penalty,
|
||||||
|
),
|
||||||
|
penalty_logits,
|
||||||
|
)
|
||||||
|
# Finally, put those penalized logit values back into their original
|
||||||
|
# positions in the logits array
|
||||||
|
return logits.at[tokens].set(penalty_logits)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def jax_dynamic(
|
||||||
|
cls,
|
||||||
|
scores: jnp.array,
|
||||||
|
tokens: jnp.array,
|
||||||
|
generated_index,
|
||||||
|
) -> jnp.array:
|
||||||
|
"""
|
||||||
|
This gets called by generate_loop_fn to apply repetition penalty
|
||||||
|
to the 1D array logits using the provided 1D array of tokens to penalize
|
||||||
|
"""
|
||||||
|
tokens = np.minimum(
|
||||||
|
tokens, tpu_mtj_backend.params["n_vocab"] - 1
|
||||||
|
) # https://github.com/google/jax/issues/3774
|
||||||
|
|
||||||
|
rpslope = np.int32(cls.rep_pen_slope)
|
||||||
|
rprange = np.int32(cls.rep_pen_range)
|
||||||
|
repetition_penalty = cls.rep_pen
|
||||||
|
|
||||||
|
clipped_rprange = rprange if rprange > 0 else tokens.shape[-1]
|
||||||
|
penalty_arange = np.roll(
|
||||||
|
np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
|
||||||
|
generated_index,
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
# Make a new array with the same length as the tokens array but with
|
||||||
|
# each element replaced by the value at the corresponding index in the
|
||||||
|
# logits array; e.g.
|
||||||
|
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
||||||
|
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
||||||
|
penalty_logits = np.take(scores, tokens)
|
||||||
|
# Repetition penalty slope
|
||||||
|
if rpslope != 0.0 and rprange > 0:
|
||||||
|
_penalty = (penalty_arange / (rprange - 1)) * 2 - 1
|
||||||
|
_penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
|
||||||
|
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
|
||||||
|
repetition_penalty = _penalty
|
||||||
|
# Divide positive values by repetition_penalty and multiply negative
|
||||||
|
# values by repetition_penalty (the academic publication that described
|
||||||
|
# this technique actually just only divided, but that would cause tokens
|
||||||
|
# with negative logits to become more likely, which is obviously wrong)
|
||||||
|
if cls.use_alt_rep_pen:
|
||||||
|
penalty_logits = np.where(
|
||||||
|
penalty_arange >= 0,
|
||||||
|
penalty_logits - np.log(repetition_penalty),
|
||||||
|
penalty_logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
penalty_logits = np.where(
|
||||||
|
penalty_arange >= 0,
|
||||||
|
np.where(
|
||||||
|
penalty_logits > 0,
|
||||||
|
penalty_logits / repetition_penalty,
|
||||||
|
penalty_logits * repetition_penalty,
|
||||||
|
),
|
||||||
|
penalty_logits,
|
||||||
|
)
|
||||||
|
# Finally, put those penalized logit values back into their original
|
||||||
|
# positions in the logits array
|
||||||
|
scores[tokens] = penalty_logits
|
||||||
return scores
|
return scores
|
||||||
|
Reference in New Issue
Block a user