diff --git a/warpers.py b/warpers.py
index ef46f318..0adf5740 100644
--- a/warpers.py
+++ b/warpers.py
@@ -1,4 +1,4 @@
-'''
+"""
 This file is AGPL-licensed.
 
 Some of the code in this file is from Clover Edition:
@@ -25,54 +25,146 @@
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
-'''
+
+---
+
+Some of the code in this file is also from Hugging Face Transformers:
+https://github.com/huggingface/transformers
+
+Transformers is licensed under the Apache-2.0 License. The changes made to this
+file mostly consist of porting warper code to the torch methods.
+"""
+# Comments mostly taken from tpu_mtj_backend.py
+
+from __future__ import annotations
 
 import torch
-from transformers import LogitsWarper
+import jax
+import jax.numpy as jnp
+import numpy as np
+import tpu_mtj_backend
 
 
-class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper):
-    def __init__(self, *args, **kwargs):
-        pass
+class Warper:
+    @staticmethod
+    def from_id(warper_id: int) -> Warper:
+        return {
+            0: TopK,
+            1: TopA,
+            2: TopP,
+            3: TailFree,
+            4: Typical,
+            5: Temperature,
+            6: RepetitionPenalty,
+        }[warper_id]
 
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        self.penalty_range = int(self.penalty_range)
-        clipped_penalty_range = min(input_ids.shape[-1], self.penalty_range)
 
-        if self.penalty != 1.0:
-            if self.penalty_range > 0:
-                if clipped_penalty_range < input_ids.shape[1]:
-                    input_ids = input_ids[..., -clipped_penalty_range:]
+class Temperature(Warper):
+    """Temperature (just divide the logits by the temperature)"""
 
-                if self.penalty_slope != 0:
-                    _penalty = (torch.arange(self.penalty_range, dtype=scores.dtype, device=scores.device)/(self.penalty_range - 1)) * 2. - 1
-                    _penalty = (self.penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (self.penalty_slope - 1))
-                    _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (self.penalty - 1)
-                    self.penalty = _penalty[..., -clipped_penalty_range:]
+    temperature: float = 0.5
 
-            score = torch.gather(scores, 1, input_ids)
-            if self.use_alt_rep_pen:
-                score = score - torch.log(self.penalty)
-            else:
-                score = torch.where(score <= 0, score * self.penalty, score / self.penalty)
-            scores.scatter_(1, input_ids, score)
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
+        return scores / cls.temperature
 
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        return scores / cls.temperature
+
+
+class TopP(Warper):
+    """
+    Top-p (after sorting the remaining tokens again in descending order of
+    logit, remove the ones that have cumulative softmax probability
+    greater than p)
+    """
+
+    top_p: float = 0.9
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
+        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs <= (1 - cls.top_p)
+
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        return scores.masked_fill(indices_to_remove, -np.inf)
+
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        # Sort the logits array in descending order, replace every element
+        # with e (Euler's number) to the power of that element, and divide
+        # each element of the new array by the sum of the elements in the
+        # new array
+        sorted_logits = -jnp.sort(-scores)
+        probabilities = jax.nn.softmax(sorted_logits)
+        # Calculate cumulative_probabilities as the prefix-sum array of
+        # probabilities
+        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
+        # We want to remove tokens with cumulative probability higher
+        # than top_p
+        sorted_indices_to_remove = cumulative_probabilities > cls.top_p
+        # Don't ever remove the token with the highest logit, even if
+        # the probability is higher than top_p
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        # Unsort and remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
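+
+
+# Illustrative example (hypothetical numbers): with top_p = 0.9 and softmax
+# probabilities {0.5, 0.3, 0.15, 0.05}, sorting ascending gives cumulative
+# sums [0.05, 0.2, 0.5, 1.0]; only the 0.05 token has cumulative mass
+# <= 1 - 0.9, so only its logit is masked to -inf.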
+
+
+class TopK(Warper):
+    """
+    Top-k (keep only the k tokens with the highest logits and remove the rest,
+    by setting their logits to negative infinity)
+    """
+
+    top_k: int = 0
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
+        top_k = min(max(cls.top_k, 1), scores.size(-1))  # Safety check
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
+        scores = scores.masked_fill(indices_to_remove, -np.inf)
         return scores
 
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        # After sorting the logits array in descending order,
+        # sorted_indices_to_remove is a 1D array that is True for tokens
+        # in the sorted logits array we want to remove and False for ones
+        # we want to keep, in this case the first top_k elements will be
+        # False and the rest will be True
+        sorted_indices_to_remove = np.arange(len(scores)) >= cls.top_k
+        # Unsort the logits array back to its original configuration and
+        # remove tokens we need to remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            np.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return np.where(indices_to_remove, -np.inf, scores)
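+
+
+# Illustrative example (hypothetical numbers): with top_k = 2 and logits
+# [77, 5, 3, 98], torch.topk keeps 98 and 77, so the logits of 5 and 3 are
+# masked to -inf.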
 
 
-class TailFreeLogitsWarper(LogitsWarper):
-    def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
-        tfs = float(tfs)
-        if tfs < 0 or tfs > 1.0:
-            raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")
-        self.tfs = tfs
-        self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
+class TailFree(Warper):
+    """
+    Tail free sampling (basically top-p a second time on remaining tokens except
+    it's the "cumulative normalized absolute second finite differences of the
+    softmax probabilities" instead of just the cumulative softmax probabilities)
+    """
 
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if self.filter_value >= 1.0:
-            return scores
+    tfs: float = 1.0
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
         sorted_logits, sorted_indices = torch.sort(scores, descending=True)
         probs = sorted_logits.softmax(dim=-1)
@@ -82,7 +174,7 @@ class TailFreeLogitsWarper(LogitsWarper):
         normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
 
         # Remove tokens with CDF value above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = normalized_d2_cdf > self.tfs
+        sorted_indices_to_remove = normalized_d2_cdf > cls.tfs
 
         # Centre the distribution around the cutoff as in the original implementation of the algorithm
         sorted_indices_to_remove = torch.cat(
@@ -94,32 +186,64 @@
             dim=-1,
         )
 
-        if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep
-            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        scores = scores.masked_fill(indices_to_remove, -np.inf)
         return scores
 
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        # Sort in descending order
+        sorted_logits = -np.sort(-scores)
 
-class TypicalLogitsWarper(LogitsWarper):
-    '''
-    Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf
-    '''
+        # Softmax again
+        probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
 
-    def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
-        typical = float(typical)
-        if typical < 0 or typical > 1.0:
-            raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")
-        self.typical = typical
-        self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
+        # Calculate the second finite differences of that array (i.e.
+        # calculate the difference array and then calculate the difference
+        # array of the difference array)
+        d2 = np.diff(np.diff(probabilities))
 
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if self.filter_value >= 1.0:
-            return scores
+        # Get the absolute values of all those second finite differences
+        d2 = np.abs(d2)
+
+        # Normalize (all elements in the array are divided by the sum of the
+        # array's elements)
+        d2 = d2 / d2.sum(axis=-1, keepdims=True)
+
+        # Get the prefix-sum array
+        cumulative_d2 = np.cumsum(d2, axis=-1)
+
+        # We will remove the tokens with a cumulative normalized absolute
+        # second finite difference larger than the TFS value
+        sorted_indices_to_remove = cumulative_d2 > cls.tfs
+
+        # Don't remove the token with the highest logit
+        sorted_indices_to_remove[0] = False
+
+        # Since the d2 array has two fewer elements than the logits array,
+        # we'll add two extra Trues to the end
+        sorted_indices_to_remove = np.pad(
+            sorted_indices_to_remove,
+            (0, 2),
+            constant_values=True,
+        )
+        # Unsort and remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            np.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return np.where(indices_to_remove, -np.inf, scores)
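+
+
+# Illustrative example (hypothetical numbers): for sorted probabilities
+# [0.5, 0.25, 0.15, 0.1], the first differences are [-0.25, -0.1, -0.05] and
+# the second differences are [0.15, 0.05]; taking absolute values and
+# normalizing gives [0.75, 0.25], with prefix sums [0.75, 1.0]. With
+# tfs = 0.9, everything after the highest-probability token is removed (the
+# two trailing positions lost to np.diff are padded back in as True).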
+
+
+class Typical(Warper):
+    """Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf"""
+
+    typical: float = 1.0
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
         # Compute softmax probabilities and the natural logarithms of them
         probs = scores.softmax(dim=-1)
         log_probs = probs.log()
@@ -141,42 +265,261 @@
         # threshold)
         _, sorted_indices = torch.sort(entropy_deviation)
         sorted_logits = probs.gather(-1, sorted_indices)
-        sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= self.typical
+        sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= cls.typical
         sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
 
-        min_tokens_to_keep = max(self.min_tokens_to_keep, 1)
+        min_tokens_to_keep = 1
         # Keep at least min_tokens_to_keep
-        sorted_indices_to_remove[..., : min_tokens_to_keep] = 0
+        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
 
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        scores = scores.masked_fill(indices_to_remove, -np.inf)
         return scores
 
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        # Compute softmax probabilities and the natural logarithms of them
+        probs = jax.nn.softmax(scores)
+        with np.errstate(divide="ignore"):
+            log_probs = np.log(probs)
 
-class TopALogitsWarper(LogitsWarper):
-    def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
-        top_a = float(top_a)
-        if top_a < 0 or top_a > 1.0:
-            raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")
-        self.top_a = top_a
-        self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
+        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
+        # in the set of softmax probabilities of the logits
+        neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)
 
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if self.filter_value >= 1.0:
-            return scores
+        # Determine absolute difference between the negative entropy and the
+        # log probabilities
+        entropy_deviation = np.abs(neg_entropy - log_probs)
+
+        # Keep certain tokens such that the sum of the entropy_deviation of the
+        # kept tokens is the smallest possible value such that the sum of the
+        # softmax probabilities of the kept tokens is at least the threshold
+        # value (by sorting the tokens in ascending order of entropy_deviation
+        # and then keeping the smallest possible number of tokens from the
+        # beginning such that sum of softmax probabilities is at or above the
+        # threshold)
+        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
+        sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= cls.typical
+        sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
+        sorted_indices_to_remove[0] = False
+
+        # Unsort and remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(entropy_deviation),
+            sorted_indices_to_remove,
+        )
+        return np.where(indices_to_remove, -jnp.inf, scores)
+
+
+class TopA(Warper):
+    """
+    Top-a (remove all tokens that have softmax probability less than a*m^2
+    where m is the maximum softmax probability)
+    """
+
+    top_a: float = 0.0
+
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
         sorted_logits, sorted_indices = torch.sort(scores, descending=True)
         probs = sorted_logits.softmax(dim=-1)
 
         # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
         probs_max = probs[..., 0, None]
-        sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a
+        sorted_indices_to_remove = probs < probs_max * probs_max * cls.top_a
 
-        if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep
-            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        scores = scores.masked_fill(indices_to_remove, -np.inf)
+        return scores
+
+    @classmethod
+    def jax(cls, scores: jnp.array) -> jnp.array:
+        # Replace every element in the logits array
+        # with e (Euler's number) to the power of that element, and divide
+        # each element of the new array by the sum of the elements in the
+        # new array
+        probabilities = np.array(jax.nn.softmax(scores), copy=True)
+        # Find the largest probability
+        probs_max = probabilities.max()
+        # Remove tokens
+        return np.where(
+            probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
+        )
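+
+
+# Illustrative example (hypothetical numbers): with top_a = 0.2 and softmax
+# probabilities [0.6, 0.25, 0.1, 0.05], the cutoff is 0.2 * 0.6**2 = 0.072,
+# so only the 0.05 token falls below it and is masked to -inf.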
+
+
+class RepetitionPenalty(Warper):
+    rep_pen: float = 1.0
+    rep_pen_slope: float = 0.0
+    rep_pen_range: int = 0
+    use_alt_rep_pen: bool = False
+
+    @classmethod
+    def torch(
+        cls,
+        scores: torch.Tensor,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        cls.rep_pen_range = int(cls.rep_pen_range)
+        clipped_penalty_range = min(input_ids.shape[-1], cls.rep_pen_range)
+
+        # Work on a local copy of the penalty so the class attribute is not
+        # permanently overwritten with a tensor when a slope is used
+        rep_pen = cls.rep_pen
+
+        if rep_pen != 1.0:
+            if cls.rep_pen_range > 0:
+                if clipped_penalty_range < input_ids.shape[1]:
+                    input_ids = input_ids[..., -clipped_penalty_range:]
+
+                if cls.rep_pen_slope != 0:
+                    _penalty = (
+                        torch.arange(
+                            cls.rep_pen_range, dtype=scores.dtype, device=scores.device
+                        )
+                        / (cls.rep_pen_range - 1)
+                    ) * 2.0 - 1
+                    _penalty = (cls.rep_pen_slope * _penalty) / (
+                        1 + torch.abs(_penalty) * (cls.rep_pen_slope - 1)
+                    )
+                    _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (rep_pen - 1)
+                    rep_pen = _penalty[..., -clipped_penalty_range:]
+
+            score = torch.gather(scores, 1, input_ids)
+            if cls.use_alt_rep_pen:
+                score = score - torch.log(rep_pen)
+            else:
+                score = torch.where(
+                    score <= 0, score * rep_pen, score / rep_pen
+                )
+            scores.scatter_(1, input_ids, score)
+
+        return scores
+
+    @classmethod
+    def jax_static(
+        cls,
+        logits: jnp.array,
+        tokens: jnp.array,
+        generated_index,
+    ) -> jnp.array:
+        """
+        This gets called by generate_loop_fn to apply repetition penalty
+        to the 1D array logits using the provided 1D array of tokens to penalize
+        """
+        # The slope is fractional, so it must stay a float (the range is a
+        # token count and stays an integer)
+        rpslope = jnp.float32(cls.rep_pen_slope)
+        rprange = jnp.int32(cls.rep_pen_range)
+        repetition_penalty = cls.rep_pen
+
+        clipped_rprange = jax.lax.cond(
+            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
+        )
+
+        penalty_arange = jnp.roll(
+            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
+            generated_index,
+            axis=-1,
+        )
+        # Make a new array with the same length as the tokens array but with
+        # each element replaced by the value at the corresponding index in the
+        # logits array; e.g.
+        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
+        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
+        penalty_logits = jnp.take(logits, tokens)
+
+        # Repetition penalty slope
+        def apply_slope(carry):
+            repetition_penalty, rprange = carry
+            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
+            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
+            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
+            return _penalty
+
+        repetition_penalty = jax.lax.cond(
+            (rpslope != 0.0)
+            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
+            apply_slope,
+            lambda carry: jnp.full(tokens.shape, carry[0]),
+            (repetition_penalty, rprange),
+        )
+        # Divide positive values by repetition_penalty and multiply negative
+        # values by repetition_penalty (the academic publication that described
+        # this technique only divided, but that would cause tokens with
+        # negative logits to become more likely, which is obviously wrong)
+        if cls.use_alt_rep_pen:
+            penalty_logits = jnp.where(
+                penalty_arange >= 0,
+                penalty_logits - jnp.log(repetition_penalty),
+                penalty_logits,
+            )
+        else:
+            penalty_logits = jnp.where(
+                penalty_arange >= 0,
+                jnp.where(
+                    penalty_logits > 0,
+                    penalty_logits / repetition_penalty,
+                    penalty_logits * repetition_penalty,
+                ),
+                penalty_logits,
+            )
+        # Finally, put those penalized logit values back into their original
+        # positions in the logits array
+        return logits.at[tokens].set(penalty_logits)
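+
+    # Note: jax_static sticks to jax.lax primitives (e.g. lax.cond instead of
+    # a Python `if`) so it can be traced and jit-compiled inside the
+    # generation loop, while jax_dynamic below performs the same penalty
+    # computation eagerly with regular NumPy control flow.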
+
+    @classmethod
+    def jax_dynamic(
+        cls,
+        scores: jnp.array,
+        tokens: jnp.array,
+        generated_index,
+    ) -> jnp.array:
+        """
+        This gets called by generate_loop_fn to apply repetition penalty
+        to the 1D array logits using the provided 1D array of tokens to penalize
+        """
+        tokens = np.minimum(
+            tokens, tpu_mtj_backend.params["n_vocab"] - 1
+        )  # https://github.com/google/jax/issues/3774
+
+        # As in jax_static, the slope is fractional and must stay a float
+        rpslope = np.float32(cls.rep_pen_slope)
+        rprange = np.int32(cls.rep_pen_range)
+        repetition_penalty = cls.rep_pen
+
+        clipped_rprange = rprange if rprange > 0 else tokens.shape[-1]
+        penalty_arange = np.roll(
+            np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
+            generated_index,
+            axis=-1,
+        )
+        # Make a new array with the same length as the tokens array but with
+        # each element replaced by the value at the corresponding index in the
+        # logits array; e.g.
+        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
+        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
+        penalty_logits = np.take(scores, tokens)
+
+        # Repetition penalty slope
+        if rpslope != 0.0 and rprange > 0:
+            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
+            _penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
+            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
+            repetition_penalty = _penalty
+
+        # Divide positive values by repetition_penalty and multiply negative
+        # values by repetition_penalty (the academic publication that described
+        # this technique only divided, but that would cause tokens with
+        # negative logits to become more likely, which is obviously wrong)
+        if cls.use_alt_rep_pen:
+            penalty_logits = np.where(
+                penalty_arange >= 0,
+                penalty_logits - np.log(repetition_penalty),
+                penalty_logits,
+            )
+        else:
+            penalty_logits = np.where(
+                penalty_arange >= 0,
+                np.where(
+                    penalty_logits > 0,
+                    penalty_logits / repetition_penalty,
+                    penalty_logits * repetition_penalty,
+                ),
+                penalty_logits,
+            )
+        # Finally, put those penalized logit values back into their original
+        # positions in the logits array
+        scores[tokens] = penalty_logits
         return scores
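+
+
+if __name__ == "__main__":
+    # Illustrative smoke test with hypothetical values (torch path only);
+    # not part of the warper port itself.
+    fake_scores = torch.tensor([[77.0, 5.0, 3.0, 98.0]])
+
+    TopK.top_k = 2
+    # Keeps the two highest logits, so 5 and 3 become -inf
+    print(TopK.torch(fake_scores.clone()))
+
+    TopP.top_p = 0.9
+    # e^98 dominates the softmax here, so only the 98 logit survives
+    print(TopP.torch(fake_scores.clone()))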