diff --git a/warpers.py b/warpers.py index bedd3445..0eda0eda 100644 --- a/warpers.py +++ b/warpers.py @@ -105,7 +105,7 @@ class TypicalLogitsWarper(LogitsWarper): Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf ''' - def __init__(self, typical: float, filter_value: -float("Inf"), min_tokens_to_keep: int = 1): + def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): typical = float(typical) if typical < 0 or typical > 1.0: raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")