diff --git a/warpers.py b/warpers.py index 9c1f88eb..fb683f50 100644 --- a/warpers.py +++ b/warpers.py @@ -168,7 +168,7 @@ class TopALogitsWarper(LogitsWarper): # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept) probs_max = probs[..., 0, None] - sorted_indices_to_remove = probs >= probs_max * probs_max * self.top_a + sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep