Fix `TypicalLogitsWarper` argument typing
This commit is contained in:
parent
d5989d4c62
commit
bbd0a83fef
|
@ -105,7 +105,7 @@ class TypicalLogitsWarper(LogitsWarper):
|
|||
Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf
|
||||
'''
|
||||
|
||||
def __init__(self, typical: float, filter_value: -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
typical = float(typical)
|
||||
if typical < 0 or typical > 1.0:
|
||||
raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")
|
||||
|
|
Loading…
Reference in New Issue