Fix an unfortunate typo in top-a warper

This commit is contained in:
Gnome Ann 2022-06-10 22:34:14 -04:00
parent fdb2a7fa4c
commit 42b7a327b2
1 changed files with 1 additions and 1 deletions

View File

@ -168,7 +168,7 @@ class TopALogitsWarper(LogitsWarper):
# Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
probs_max = probs[..., 0, None]
sorted_indices_to_remove = probs >= probs_max * probs_max * self.top_a
sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep