Fix `TypicalLogitsWarper` argument typing

This commit is contained in:
Gnome Ann 2022-03-27 16:59:23 -04:00
parent d5989d4c62
commit bbd0a83fef
1 changed files with 1 additions and 1 deletions

View File

@ -105,7 +105,7 @@ class TypicalLogitsWarper(LogitsWarper):
Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf
'''
def __init__(self, typical: float, filter_value: -float("Inf"), min_tokens_to_keep: int = 1):
def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
typical = float(typical)
if typical < 0 or typical > 1.0:
raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")