Fix `TypicalLogitsWarper` argument typing
This commit is contained in:
parent
d5989d4c62
commit
bbd0a83fef
|
@ -105,7 +105,7 @@ class TypicalLogitsWarper(LogitsWarper):
|
||||||
Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf
|
Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, typical: float, filter_value: -float("Inf"), min_tokens_to_keep: int = 1):
|
def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
typical = float(typical)
|
typical = float(typical)
|
||||||
if typical < 0 or typical > 1.0:
|
if typical < 0 or typical > 1.0:
|
||||||
raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")
|
raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")
|
||||||
|
|
Loading…
Reference in New Issue