Typo fix in `TypicalLogitsWarper`
This commit is contained in:
parent
bbd0a83fef
commit
e2cd49d552
|
@ -139,7 +139,7 @@ class TypicalLogitsWarper(LogitsWarper):
|
|||
_, sorted_indices = torch.sort(entropy_deviation)
|
||||
sorted_logits = probs.gather(-1, sorted_indices)
|
||||
sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= self.typical
|
||||
sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dim=-1)
|
||||
sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
|
||||
|
||||
min_tokens_to_keep = max(self.min_tokens_to_keep, 1)
|
||||
# Keep at least min_tokens_to_keep
|
||||
|
|
Loading…
Reference in New Issue