Model: Samplers pt. 1

commit f882979c88 (parent f771ae38cf)
Author: somebody
Date: 2023-02-26 16:09:22 -06:00


"""
This file is AGPL-licensed.

Some of the code in this file is from Clover Edition:
...
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

---

Some of the code in this file is also from Hugging Face transformers:
https://github.com/huggingface/transformers
Transformers is licensed under the Apache-2.0 License. The changes made to this
file are mostly porting warper code to the torch methods.
"""

# Comments mostly taken from tpu_mtj_backend.py

from __future__ import annotations

import torch
import jax
import jax.numpy as jnp
import numpy as np

import tpu_mtj_backend

class Warper:
    @staticmethod
    def from_id(warper_id: int) -> Warper:
        return {
            0: TopK,
            1: TopA,
            2: TopP,
            3: TailFree,
            4: Typical,
            5: Temperature,
            6: RepetitionPenalty,
        }[warper_id]
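
# Illustrative usage sketch (not part of the original commit): resolve a
# warper by its numeric ID and apply its torch path to a tensor of logits.
# The variable names here are assumptions for illustration only.
#
#     warper = Warper.from_id(5)     # Temperature, per the table above
#     scores = warper.torch(scores)  # scores: torch.FloatTensor of logits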

class Temperature(Warper):
    """Temperature (just divide the logits by the temperature)"""

    temperature: float = 0.5

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        return scores / cls.temperature

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        return scores / cls.temperature
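
# Worked example (illustrative): with logits [2.0, 1.0] and temperature 0.5,
# the warped logits are [4.0, 2.0]; softmax moves from roughly [0.73, 0.27]
# to roughly [0.88, 0.12]. Temperatures below 1 sharpen the distribution,
# temperatures above 1 flatten it.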

class TopP(Warper):
    """
    Top-p (after sorting the remaining tokens again in descending order of
    logit, remove the ones that have cumulative softmax probability
    greater than p)
    """

    top_p: float = 0.9

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs <= (1 - cls.top_p)

        # Scatter sorted tensors back to the original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        return scores.masked_fill(indices_to_remove, -np.inf)

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        # Sort the logits array in descending order, replace every element
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        sorted_logits = -jnp.sort(-scores)
        probabilities = jax.nn.softmax(sorted_logits)

        # Calculate cumulative_probabilities as the prefix-sum array of
        # probabilities
        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)

        # We want to remove tokens with cumulative probability higher
        # than top_p
        sorted_indices_to_remove = cumulative_probabilities > cls.top_p

        # Don't ever remove the token with the highest logit, even if
        # the probability is higher than top_p
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)

        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)
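
# Worked example (illustrative): with sorted probabilities
# [0.5, 0.3, 0.15, 0.05] and top_p = 0.8, the prefix sums are
# [0.5, 0.8, 0.95, 1.0]; the first two tokens stay and the tail two are
# masked to -inf. Both paths implement the same cutoff, the torch one via
# the ascending-sort formulation and the jax one via the descending sort.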

class TopK(Warper):
    """
    Top-k (keep only the k tokens with the highest logits and remove the rest,
    by setting their logits to negative infinity)
    """

    top_k: int = 0

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        top_k = min(max(cls.top_k, 1), scores.size(-1))  # Safety check

        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores = scores.masked_fill(indices_to_remove, -np.inf)
        return scores

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        # After sorting the logits array in descending order,
        # sorted_indices_to_remove is a 1D array that is True for tokens
        # in the sorted logits array we want to remove and False for ones
        # we want to keep, in this case the first top_k elements will be
        # False and the rest will be True
        sorted_indices_to_remove = np.arange(len(scores)) >= cls.top_k

        # Unsort the logits array back to its original configuration and
        # remove tokens we need to remove
        _, indices_to_remove = jax.lax.sort_key_val(
            np.argsort(-scores),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -np.inf, scores)
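
# Worked example (illustrative): with logits [98, 77, 5, 3] and top_k = 2,
# the cutoff is the second-highest logit (77), so 5 and 3 are masked to
# -inf while 98 and 77 survive unchanged.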

class TailFree(Warper):
    """
    Tail free sampling (basically top-p a second time on remaining tokens except
    it's the "cumulative normalized absolute second finite differences of the
    softmax probabilities" instead of just the cumulative softmax probabilities)
    """

    tfs: float = 1.0

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Compute the cumulative normalized absolute second finite
        # differences of the softmax probabilities
        d2 = probs.diff().diff().abs()
        normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
        normalized_d2_cdf = normalized_d2.cumsum(dim=-1)

        # Remove tokens with CDF value above the threshold (token with 0 are kept)
        sorted_indices_to_remove = normalized_d2_cdf > cls.tfs

        # Centre the distribution around the cutoff as in the original implementation of the algorithm
        sorted_indices_to_remove = torch.cat(
            (
                torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
                sorted_indices_to_remove,
                torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
            ),
            dim=-1,
        )

        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        scores = scores.masked_fill(indices_to_remove, -np.inf)
        return scores
    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        # Sort in descending order
        sorted_logits = -np.sort(-scores)

        # Softmax again
        probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)

        # Calculate the second finite differences of that array (i.e.
        # calculate the difference array and then calculate the difference
        # array of the difference array)
        d2 = np.diff(np.diff(probabilities))

        # Get the absolute values of all those second finite differences
        d2 = np.abs(d2)

        # Normalize (all elements in the array are divided by the sum of the
        # array's elements)
        d2 = d2 / d2.sum(axis=-1, keepdims=True)

        # Get the prefix-sum array
        cumulative_d2 = np.cumsum(d2, axis=-1)

        # We will remove the tokens with a cumulative normalized absolute
        # second finite difference larger than the TFS value
        sorted_indices_to_remove = cumulative_d2 > cls.tfs

        # Don't remove the token with the highest logit
        sorted_indices_to_remove[0] = False

        # Since the d2 array has two fewer elements than the logits array,
        # we'll add two extra Trues to the end
        sorted_indices_to_remove = np.pad(
            sorted_indices_to_remove,
            (0, 2),
            constant_values=True,
        )

        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            np.argsort(-scores),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -np.inf, scores)
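
# Worked example (illustrative): for sorted probabilities
# [0.4, 0.3, 0.15, 0.1, 0.05], the first differences are
# [-0.1, -0.15, -0.05, -0.05] and the second differences are
# [-0.05, 0.1, 0.0]. Taking absolute values and normalizing gives
# [1/3, 2/3, 0], whose prefix sums [1/3, 1, 1] are compared against tfs;
# the flat tail of the probability curve is what ends up being removed.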

class Typical(Warper):
    """Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf"""

    typical: float = 1.0

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        # Compute softmax probabilities and the natural logarithms of them
        probs = scores.softmax(dim=-1)
        log_probs = probs.log()

        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
        # in the set of softmax probabilities of the logits
        neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)

        # Determine absolute difference between the negative entropy and the
        # log probabilities
        entropy_deviation = (neg_entropy - log_probs).abs()

        # Keep certain tokens such that the sum of the entropy_deviation of the
        # kept tokens is the smallest possible value such that the sum of the
        # softmax probabilities of the kept tokens is at least the threshold
        # value (by sorting the tokens in ascending order of entropy_deviation
        # and then keeping the smallest possible number of tokens from the
        # beginning such that sum of softmax probabilities is at or above the
        # threshold)
        _, sorted_indices = torch.sort(entropy_deviation)
        sorted_logits = probs.gather(-1, sorted_indices)
        sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= cls.typical
        sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)

        min_tokens_to_keep = 1
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        scores = scores.masked_fill(indices_to_remove, -np.inf)
        return scores
    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        # Compute softmax probabilities and the natural logarithms of them
        probs = jax.nn.softmax(scores)
        with np.errstate(divide="ignore"):
            log_probs = np.log(probs)

        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
        # in the set of softmax probabilities of the logits
        neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)

        # Determine absolute difference between the negative entropy and the
        # log probabilities
        entropy_deviation = np.abs(neg_entropy - log_probs)

        # Keep certain tokens such that the sum of the entropy_deviation of the
        # kept tokens is the smallest possible value such that the sum of the
        # softmax probabilities of the kept tokens is at least the threshold
        # value (by sorting the tokens in ascending order of entropy_deviation
        # and then keeping the smallest possible number of tokens from the
        # beginning such that sum of softmax probabilities is at or above the
        # threshold)
        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
        sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= cls.typical
        sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove[0] = False

        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(entropy_deviation),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -jnp.inf, scores)
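
# Worked example (illustrative): for probabilities [0.5, 0.3, 0.2], the
# entropy is about 1.03 nats, so neg_entropy is about -1.03; the log
# probabilities are about [-0.69, -1.20, -1.61], giving entropy deviations
# of about [0.34, 0.17, 0.58]. Sorting by deviation orders the tokens
# [0.3, 0.5, 0.2]; with typical = 0.75 the cumulative sums [0.3, 0.8, 1.0]
# keep the 0.3 and 0.5 tokens and remove the 0.2 token.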

class TopA(Warper):
    """
    Top-a (remove all tokens that have softmax probability less than a*m^2,
    where m is the maximum softmax probability and a is the top_a value)
    """

    top_a: float = 0.0

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
        probs_max = probs[..., 0, None]
        sorted_indices_to_remove = probs < probs_max * probs_max * cls.top_a

        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        scores = scores.masked_fill(indices_to_remove, -np.inf)
        return scores

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
        # Replace every element in the logits array
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        probabilities = np.array(jax.nn.softmax(scores), copy=True)

        # Find the largest probability
        probs_max = probabilities.max()

        # Remove tokens
        return np.where(
            probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
        )
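
# Worked example (illustrative): with probabilities [0.64, 0.2, 0.1, 0.06]
# and top_a = 0.5, the cutoff is 0.5 * 0.64^2 ≈ 0.205, so only the 0.64
# token survives. Because the cutoff scales with the square of the maximum
# probability, a flat distribution keeps many tokens while a peaked one
# keeps few.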

class RepetitionPenalty(Warper):
    rep_pen: float = 1.0
    rep_pen_slope: float = 0.0
    rep_pen_range: int = 0
    use_alt_rep_pen: bool = False

    @classmethod
    def torch(
        cls, scores: torch.Tensor, input_ids: torch.LongTensor
    ) -> torch.Tensor:
        cls.rep_pen_range = int(cls.rep_pen_range)
        clipped_penalty_range = min(input_ids.shape[-1], cls.rep_pen_range)

        if cls.rep_pen != 1.0:
            if cls.rep_pen_range > 0:
                if clipped_penalty_range < input_ids.shape[1]:
                    input_ids = input_ids[..., -clipped_penalty_range:]

                if cls.rep_pen_slope != 0:
                    _penalty = (
                        torch.arange(
                            cls.rep_pen_range, dtype=scores.dtype, device=scores.device
                        )
                        / (cls.rep_pen_range - 1)
                    ) * 2.0 - 1
                    _penalty = (cls.rep_pen_slope * _penalty) / (
                        1 + torch.abs(_penalty) * (cls.rep_pen_slope - 1)
                    )
                    _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (cls.rep_pen - 1)
                    cls.rep_pen = _penalty[..., -clipped_penalty_range:]

            score = torch.gather(scores, 1, input_ids)
            if cls.use_alt_rep_pen:
                score = score - torch.log(cls.rep_pen)
            else:
                score = torch.where(
                    score <= 0, score * cls.rep_pen, score / cls.rep_pen
                )
            scores.scatter_(1, input_ids, score)

        return scores
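
    # Worked example (illustrative): with rep_pen = 2.0 and no slope, a
    # previously seen token with logit 4.0 becomes 2.0 (divided) and one
    # with logit -4.0 becomes -8.0 (multiplied), so both become less
    # likely; the alternative formulation instead subtracts
    # log(2.0) ≈ 0.69 from every penalized logit.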
    @classmethod
    def jax_static(
        cls,
        logits: jnp.array,
        tokens: jnp.array,
        generated_index,
    ) -> jnp.array:
        """
        This gets called by generate_loop_fn to apply repetition penalty
        to the 1D array logits using the provided 1D array of tokens to penalize
        """
        rpslope = jnp.int32(cls.rep_pen_slope)
        rprange = jnp.int32(cls.rep_pen_range)
        repetition_penalty = cls.rep_pen

        clipped_rprange = jax.lax.cond(
            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
        )
        penalty_arange = jnp.roll(
            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
            generated_index,
            axis=-1,
        )

        # Make a new array with the same length as the tokens array but with
        # each element replaced by the value at the corresponding index in the
        # logits array; e.g.
        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
        penalty_logits = jnp.take(logits, tokens)

        # Repetition penalty slope
        def apply_slope(carry):
            repetition_penalty, rprange = carry
            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
            return _penalty

        repetition_penalty = jax.lax.cond(
            (rpslope != 0.0)
            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
            apply_slope,
            lambda carry: jnp.full(tokens.shape, carry[0]),
            (repetition_penalty, rprange),
        )

        # Divide positive values by repetition_penalty and multiply negative
        # values by repetition_penalty (the academic publication that described
        # this technique only divided, but that would cause tokens with
        # negative logits to become more likely, which is obviously wrong)
        if cls.use_alt_rep_pen:
            penalty_logits = jnp.where(
                penalty_arange >= 0,
                penalty_logits - jnp.log(repetition_penalty),
                penalty_logits,
            )
        else:
            penalty_logits = jnp.where(
                penalty_arange >= 0,
                jnp.where(
                    penalty_logits > 0,
                    penalty_logits / repetition_penalty,
                    penalty_logits * repetition_penalty,
                ),
                penalty_logits,
            )

        # Finally, put those penalized logit values back into their original
        # positions in the logits array
        return logits.at[tokens].set(penalty_logits)
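
    # Note on penalty_arange (illustrative reading of the code above, not
    # from the original commit): after the roll, entries run from
    # clipped_rprange - tokens.shape[-1] up to clipped_rprange - 1, so
    # negative entries mark tokens that fall outside the repetition penalty
    # window and are left untouched by the penalty_arange >= 0 masks. When a
    # slope is set, the oldest token inside the window gets a penalty of 1
    # (no penalty) and the newest gets the full rep_pen.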
    @classmethod
    def jax_dynamic(
        cls,
        scores: jnp.array,
        tokens: jnp.array,
        generated_index,
    ) -> jnp.array:
        """
        This gets called by generate_loop_fn to apply repetition penalty
        to the 1D array logits using the provided 1D array of tokens to penalize
        """
        tokens = np.minimum(
            tokens, tpu_mtj_backend.params["n_vocab"] - 1
        )  # https://github.com/google/jax/issues/3774

        rpslope = np.int32(cls.rep_pen_slope)
        rprange = np.int32(cls.rep_pen_range)
        repetition_penalty = cls.rep_pen

        clipped_rprange = rprange if rprange > 0 else tokens.shape[-1]
        penalty_arange = np.roll(
            np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
            generated_index,
            axis=-1,
        )

        # Make a new array with the same length as the tokens array but with
        # each element replaced by the value at the corresponding index in the
        # logits array; e.g.
        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
        penalty_logits = np.take(scores, tokens)

        # Repetition penalty slope
        if rpslope != 0.0 and rprange > 0:
            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
            _penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
            repetition_penalty = _penalty

        # Divide positive values by repetition_penalty and multiply negative
        # values by repetition_penalty (the academic publication that described
        # this technique only divided, but that would cause tokens with
        # negative logits to become more likely, which is obviously wrong)
        if cls.use_alt_rep_pen:
            penalty_logits = np.where(
                penalty_arange >= 0,
                penalty_logits - np.log(repetition_penalty),
                penalty_logits,
            )
        else:
            penalty_logits = np.where(
                penalty_arange >= 0,
                np.where(
                    penalty_logits > 0,
                    penalty_logits / repetition_penalty,
                    penalty_logits * repetition_penalty,
                ),
                penalty_logits,
            )

        # Finally, put those penalized logit values back into their original
        # positions in the logits array
        scores[tokens] = penalty_logits
        return scores
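
# Illustrative end-to-end sketch (not part of the original commit): applying
# a stack of warpers on the torch path. The warper order here is an
# assumption for illustration; RepetitionPenalty.torch additionally needs
# the input_ids of the tokens generated so far.
#
#     for warper_id in (6, 0, 2, 5):  # RepetitionPenalty, TopK, TopP, Temperature
#         warper = Warper.from_id(warper_id)
#         if warper is RepetitionPenalty:
#             scores = warper.torch(scores, input_ids)
#         else:
#             scores = warper.torch(scores)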