Model: Samplers pt. 1

somebody
2023-02-26 16:09:22 -06:00
parent f771ae38cf
commit f882979c88


@@ -1,4 +1,4 @@
"""
This file is AGPL-licensed.
Some of the code in this file is from Clover Edition:
@@ -25,54 +25,146 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---
Some of the code in this file is also from Hugging Face transformers:
https://github.com/huggingface/transformers
Transformers is licensed under the Apache-2.0 License. The changes made to this
file are mostly porting warper code to the torch methods.
"""
# Comments mostly taken from tpu_mtj_backend.py
from __future__ import annotations
import torch
from transformers import LogitsWarper
import jax
import jax.numpy as jnp
import numpy as np
import tpu_mtj_backend
class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper):
    def __init__(self, *args, **kwargs):
        pass


class Warper:
    @staticmethod
    def from_id(warper_id: int) -> Warper:
        return {
            0: TopK,
            1: TopA,
            2: TopP,
            3: TailFree,
            4: Typical,
            5: Temperature,
            6: RepetitionPenalty,
        }[warper_id]
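A quick sketch of how this lookup is meant to be used (a hypothetical caller, not part of this commit; it assumes the module is importable as `warpers`):

    import torch
    import warpers

    warper_cls = warpers.Warper.from_id(2)   # 2 maps to TopP in the table above
    logits = torch.tensor([[3.0, 1.0, 0.5, -2.0]])
    filtered = warper_cls.torch(logits)      # pruned tokens come back as -inf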
class Temperature(Warper):
    """Temperature (just divide the logits by the temperature)"""

    temperature: float = 0.5

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        return scores / cls.temperature

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        return scores / cls.temperature
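As a sanity check on what temperature does, a minimal standalone demo (illustrative values, not from this commit):

    import torch

    logits = torch.tensor([2.0, 1.0, 0.0])
    for t in (0.5, 1.0, 2.0):
        # Dividing by t < 1 sharpens the softmax; t > 1 flattens it
        print(t, torch.softmax(logits / t, dim=-1))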
class TopP(Warper):
    """
    Top-p (after sorting the remaining tokens again in descending order of
    logit, remove the ones that have cumulative softmax probability
    greater than p)
    """

    top_p: float = 0.9

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
        sorted_indices_to_remove = cumulative_probs <= (1 - cls.top_p)

        # Scatter sorted tensors back to the original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        return scores.masked_fill(indices_to_remove, -np.inf)

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        # Sort the logits array in descending order, replace every element
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        sorted_logits = -jnp.sort(-scores)
        probabilities = jax.nn.softmax(sorted_logits)

        # Calculate cumulative_probabilities as the prefix-sum array of
        # probabilities
        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)

        # We want to remove tokens with cumulative probability higher
        # than top_p
        sorted_indices_to_remove = cumulative_probabilities > cls.top_p

        # Don't ever remove the token with the highest logit, even if
        # the probability is higher than top_p
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)

        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)
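For a concrete feel of the torch path above, a small worked example (made-up values):

    import torch

    scores = torch.tensor([[4.0, 3.0, 1.0, 0.0]])
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # With top_p = 0.9 the ascending cumulative sums are roughly
    # [0.013, 0.048, 0.303, 1.0], so only the first two (<= 0.1) are dropped
    sorted_indices_to_remove = cumulative_probs <= (1 - 0.9)
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    print(scores.masked_fill(indices_to_remove, -float("inf")))  # [4., 3., -inf, -inf]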
class TopK(Warper):
    """
    Top-k (keep only the k tokens with the highest logits and remove the rest,
    by setting their logits to negative infinity)
    """

    top_k: int = 0

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        top_k = min(max(cls.top_k, 1), scores.size(-1))  # Safety check

        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores = scores.masked_fill(indices_to_remove, -np.inf)
        return scores

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        # After sorting the logits array in descending order,
        # sorted_indices_to_remove is a 1D array that is True for tokens
        # in the sorted logits array we want to remove and False for ones
        # we want to keep, in this case the first top_k elements will be
        # False and the rest will be True
        sorted_indices_to_remove = np.arange(len(scores)) >= cls.top_k

        # Unsort the logits array back to its original configuration and
        # remove tokens we need to remove
        _, indices_to_remove = jax.lax.sort_key_val(
            np.argsort(-scores),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -np.inf, scores)
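The same cutoff can be checked in isolation; a toy run of the torch path (illustrative values):

    import torch

    scores = torch.tensor([[4.0, 3.0, 1.0, 0.0]])
    top_k = 2
    kth_best = torch.topk(scores, top_k)[0][..., -1, None]       # 3.0 here
    print(scores.masked_fill(scores < kth_best, -float("inf")))  # [4., 3., -inf, -inf]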
class TailFree(Warper):
    """
    Tail free sampling (basically top-p a second time on remaining tokens except
    it's the "cumulative normalized absolute second finite differences of the
    softmax probabilities" instead of just the cumulative softmax probabilities)
    """

    tfs: float = 1.0

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)
@@ -82,7 +174,7 @@ class TailFreeLogitsWarper(LogitsWarper):
        normalized_d2_cdf = normalized_d2.cumsum(dim=-1)

        # Remove tokens with CDF value above the threshold (tokens with 0 are kept)
        sorted_indices_to_remove = normalized_d2_cdf > cls.tfs

        # Centre the distribution around the cutoff as in the original implementation of the algorithm
        sorted_indices_to_remove = torch.cat(
@@ -94,32 +186,64 @@ class TailFreeLogitsWarper(LogitsWarper):
            dim=-1,
        )

        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        return scores.masked_fill(indices_to_remove, -np.inf)
    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        # Sort in descending order
        sorted_logits = -np.sort(-scores)

        # Softmax again
        probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)

        # Calculate the second finite differences of that array (i.e.
        # calculate the difference array and then calculate the difference
        # array of the difference array)
        d2 = np.diff(np.diff(probabilities))

        # Get the absolute values of all those second finite differences
        d2 = np.abs(d2)

        # Normalize (all elements in the array are divided by the sum of the
        # array's elements)
        d2 = d2 / d2.sum(axis=-1, keepdims=True)

        # Get the prefix-sum array
        cumulative_d2 = np.cumsum(d2, axis=-1)

        # We will remove the tokens with a cumulative normalized absolute
        # second finite difference larger than the TFS value
        sorted_indices_to_remove = cumulative_d2 > cls.tfs

        # Don't remove the token with the highest logit
        sorted_indices_to_remove[0] = False

        # Since the d2 array has two fewer elements than the logits array,
        # we'll add two extra Trues to the end
        sorted_indices_to_remove = np.pad(
            sorted_indices_to_remove,
            (0, 2),
            constant_values=True,
        )

        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            np.argsort(-scores),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -np.inf, scores)
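To see why the second finite differences find the "tail", a tiny numpy walk-through on a made-up, already-sorted probability vector (illustrative only):

    import numpy as np

    probs = np.array([0.5, 0.25, 0.15, 0.06, 0.03, 0.01])  # sorted descending
    d2 = np.abs(np.diff(np.diff(probs)))   # curvature of the probability curve
    cdf = np.cumsum(d2 / d2.sum())
    print(cdf)  # tokens past the point where cdf exceeds tfs sit on the flat tail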
class Typical(Warper):
    """Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf"""

    typical: float = 1.0

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        # Compute softmax probabilities and the natural logarithms of them
        probs = scores.softmax(dim=-1)
        log_probs = probs.log()
@@ -141,42 +265,261 @@ class TypicalLogitsWarper(LogitsWarper):
        # threshold)
        _, sorted_indices = torch.sort(entropy_deviation)
        sorted_logits = probs.gather(-1, sorted_indices)
        sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= cls.typical
        sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)

        min_tokens_to_keep = 1
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        return scores.masked_fill(indices_to_remove, -np.inf)
    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        # Compute softmax probabilities and the natural logarithms of them
        probs = jax.nn.softmax(scores)
        with np.errstate(divide="ignore"):
            log_probs = np.log(probs)

        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
        # in the set of softmax probabilities of the logits
        neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)

        # Determine absolute difference between the negative entropy and the
        # log probabilities
        entropy_deviation = np.abs(neg_entropy - log_probs)

        # Keep certain tokens such that the sum of the entropy_deviation of the
        # kept tokens is the smallest possible value such that the sum of the
        # softmax probabilities of the kept tokens is at least the threshold
        # value (by sorting the tokens in ascending order of entropy_deviation
        # and then keeping the smallest possible number of tokens from the
        # beginning such that sum of softmax probabilities is at or above the
        # threshold)
        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
        sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= cls.typical
        sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove[0] = False

        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(entropy_deviation),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -jnp.inf, scores)
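A standalone numpy sketch of the selection rule (toy numbers, not from this file): tokens are ranked by how far their information content is from the distribution's entropy, then kept in that order until their total probability reaches the threshold.

    import numpy as np

    probs = np.array([0.5, 0.3, 0.15, 0.05])
    neg_entropy = np.sum(probs * np.log(probs))
    entropy_deviation = np.abs(neg_entropy - np.log(probs))
    order = np.argsort(entropy_deviation)           # most "typical" tokens first
    cumulative = np.cumsum(probs[order])
    keep = order[: np.searchsorted(cumulative, 0.9) + 1]
    print(keep)  # with typical = 0.9 this keeps tokens 1, 0 and 2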
class TopA(Warper):
    """
    Top-a (remove all tokens that have softmax probability less than top_a*m^2
    where m is the maximum softmax probability)
    """

    top_a: float = 0.0

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Remove tokens with probability less than top_a*(max(probs))^2 (tokens with 0 are kept)
        probs_max = probs[..., 0, None]
        sorted_indices_to_remove = probs < probs_max * probs_max * cls.top_a

        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        return scores.masked_fill(indices_to_remove, -np.inf)
    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        # Replace every element in the logits array
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        probabilities = np.array(jax.nn.softmax(scores), copy=True)

        # Find the largest probability
        probs_max = probabilities.max()

        # Remove tokens
        return np.where(
            probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
        )
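The cutoff is easiest to see with plain numbers (toy example): with top_a = 0.2 and a maximum probability of 0.6, the threshold is 0.2 * 0.6^2 = 0.072, and anything below it is dropped.

    import numpy as np

    probs = np.array([0.6, 0.2, 0.1, 0.06, 0.04])
    top_a = 0.2
    threshold = top_a * probs.max() ** 2    # 0.072
    print(probs >= threshold)               # [ True  True  True False False]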
class RepetitionPenalty(Warper):
    rep_pen: float = 1.0
    rep_pen_slope: float = 0.0
    rep_pen_range: int = 0
    use_alt_rep_pen: bool = False

    @classmethod
    def torch(cls, scores: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        cls.rep_pen_range = int(cls.rep_pen_range)
        clipped_penalty_range = min(input_ids.shape[-1], cls.rep_pen_range)

        if cls.rep_pen != 1.0:
            if cls.rep_pen_range > 0:
                if clipped_penalty_range < input_ids.shape[1]:
                    input_ids = input_ids[..., -clipped_penalty_range:]

                if cls.rep_pen_slope != 0:
                    _penalty = (
                        torch.arange(
                            cls.rep_pen_range, dtype=scores.dtype, device=scores.device
                        )
                        / (cls.rep_pen_range - 1)
                    ) * 2.0 - 1
                    _penalty = (cls.rep_pen_slope * _penalty) / (
                        1 + torch.abs(_penalty) * (cls.rep_pen_slope - 1)
                    )
                    _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (cls.rep_pen - 1)
                    cls.rep_pen = _penalty[..., -clipped_penalty_range:]

            score = torch.gather(scores, 1, input_ids)
            if cls.use_alt_rep_pen:
                score = score - torch.log(cls.rep_pen)
            else:
                score = torch.where(
                    score <= 0, score * cls.rep_pen, score / cls.rep_pen
                )
            scores.scatter_(1, input_ids, score)
        return scores
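Stripped of the range and slope machinery, the core penalty update looks like this (a minimal sketch with invented values):

    import torch

    scores = torch.tensor([[2.0, -1.0, 0.5, 3.0]])
    input_ids = torch.tensor([[3, 1]])   # tokens already seen in the context
    rep_pen = 1.2
    score = torch.gather(scores, 1, input_ids)
    # Divide positive logits, multiply negative ones, so both become less likely
    score = torch.where(score <= 0, score * rep_pen, score / rep_pen)
    scores.scatter_(1, input_ids, score)
    print(scores)  # token 3: 3.0 -> 2.5, token 1: -1.0 -> -1.2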
    @classmethod
    def jax_static(
        cls,
        logits: jnp.array,
        tokens: jnp.array,
        generated_index,
    ) -> jnp.array:
        """
        This gets called by generate_loop_fn to apply repetition penalty
        to the 1D array logits using the provided 1D array of tokens to penalize
        """
        rpslope = jnp.int32(cls.rep_pen_slope)
        rprange = jnp.int32(cls.rep_pen_range)
        repetition_penalty = cls.rep_pen

        clipped_rprange = jax.lax.cond(
            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
        )
        penalty_arange = jnp.roll(
            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
            generated_index,
            axis=-1,
        )

        # Make a new array with the same length as the tokens array but with
        # each element replaced by the value at the corresponding index in the
        # logits array; e.g.
        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
        penalty_logits = jnp.take(logits, tokens)

        # Repetition penalty slope
        def apply_slope(carry):
            repetition_penalty, rprange = carry
            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
            return _penalty

        repetition_penalty = jax.lax.cond(
            (rpslope != 0.0)
            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
            apply_slope,
            lambda carry: jnp.full(tokens.shape, carry[0]),
            (repetition_penalty, rprange),
        )

        # Divide positive values by repetition_penalty and multiply negative
        # values by repetition_penalty (the academic publication that described
        # this technique actually only divided, but that would cause tokens
        # with negative logits to become more likely, which is obviously wrong)
        if cls.use_alt_rep_pen:
            penalty_logits = jnp.where(
                penalty_arange >= 0,
                penalty_logits - jnp.log(repetition_penalty),
                penalty_logits,
            )
        else:
            penalty_logits = jnp.where(
                penalty_arange >= 0,
                jnp.where(
                    penalty_logits > 0,
                    penalty_logits / repetition_penalty,
                    penalty_logits * repetition_penalty,
                ),
                penalty_logits,
            )

        # Finally, put those penalized logit values back into their original
        # positions in the logits array
        return logits.at[tokens].set(penalty_logits)
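The take/set round trip at the heart of this method can be checked on its own (this example mirrors the numbers in the comment above):

    import jax.numpy as jnp

    logits = jnp.array([77.0, 5.0, 3.0, 98.0])
    tokens = jnp.array([0, 1, 2, 3, 2, 3, 1])
    penalty_logits = jnp.take(logits, tokens)         # [77, 5, 3, 98, 3, 98, 5]
    print(logits.at[tokens].set(penalty_logits / 2))  # halves every penalized slot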
    @classmethod
    def jax_dynamic(
        cls,
        scores: jnp.array,
        tokens: jnp.array,
        generated_index,
    ) -> jnp.array:
        """
        This gets called by generate_loop_fn to apply repetition penalty
        to the 1D array scores using the provided 1D array of tokens to penalize
        """
        tokens = np.minimum(
            tokens, tpu_mtj_backend.params["n_vocab"] - 1
        )  # https://github.com/google/jax/issues/3774

        rpslope = np.int32(cls.rep_pen_slope)
        rprange = np.int32(cls.rep_pen_range)
        repetition_penalty = cls.rep_pen

        clipped_rprange = rprange if rprange > 0 else tokens.shape[-1]
        penalty_arange = np.roll(
            np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
            generated_index,
            axis=-1,
        )

        # Make a new array with the same length as the tokens array but with
        # each element replaced by the value at the corresponding index in the
        # scores array; e.g.
        # if scores is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
        penalty_logits = np.take(scores, tokens)

        # Repetition penalty slope
        if rpslope != 0.0 and rprange > 0:
            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
            _penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
            repetition_penalty = _penalty

        # Divide positive values by repetition_penalty and multiply negative
        # values by repetition_penalty (the academic publication that described
        # this technique actually only divided, but that would cause tokens
        # with negative logits to become more likely, which is obviously wrong)
        if cls.use_alt_rep_pen:
            penalty_logits = np.where(
                penalty_arange >= 0,
                penalty_logits - np.log(repetition_penalty),
                penalty_logits,
            )
        else:
            penalty_logits = np.where(
                penalty_arange >= 0,
                np.where(
                    penalty_logits > 0,
                    penalty_logits / repetition_penalty,
                    penalty_logits * repetition_penalty,
                ),
                penalty_logits,
            )

        # Finally, put those penalized logit values back into their original
        # positions in the scores array
        scores[tokens] = penalty_logits
        return scores
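Finally, the slope formula is easier to read as a curve: over a window of rep_pen_range tokens it ramps the penalty from a no-op on the oldest position up to the full rep_pen on the newest. A quick numpy check (made-up parameters):

    import numpy as np

    rep_pen, rpslope, rprange = 1.5, 3.0, 8
    x = (np.arange(rprange) / (rprange - 1)) * 2 - 1     # -1 .. 1 across the window
    x = (rpslope * x) / (1 + np.abs(x) * (rpslope - 1))  # sigmoid-like reshaping
    print(1 + ((x + 1) / 2) * (rep_pen - 1))             # ramps from 1.0 up to 1.5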