Typo fix in `TypicalLogitsWarper`

This commit is contained in:
Gnome Ann 2022-03-27 17:08:57 -04:00
parent bbd0a83fef
commit e2cd49d552
1 changed files with 1 additions and 1 deletions

View File

@ -139,7 +139,7 @@ class TypicalLogitsWarper(LogitsWarper):
_, sorted_indices = torch.sort(entropy_deviation)
sorted_logits = probs.gather(-1, sorted_indices)
sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= self.typical
sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dim=-1)
sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
min_tokens_to_keep = max(self.min_tokens_to_keep, 1)
# Keep at least min_tokens_to_keep