Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Typical sampling
warpers.py: 52 changed lines
@@ -62,7 +62,7 @@ class TailFreeLogitsWarper(LogitsWarper):
     def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         tfs = float(tfs)
         if tfs < 0 or tfs > 1.0:
-            raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
+            raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")
         self.tfs = tfs
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
@@ -98,3 +98,53 @@ class TailFreeLogitsWarper(LogitsWarper):
         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
         scores = scores.masked_fill(indices_to_remove, self.filter_value)
         return scores
+
+
+class TypicalLogitsWarper(LogitsWarper):
+    '''
+    Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf
+    '''
+
+    def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        typical = float(typical)
+        if typical < 0 or typical > 1.0:
+            raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")
+        self.typical = typical
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if self.filter_value >= 1.0:
+            return scores
+
+        # Compute softmax probabilities and the natural logarithms of them
+        probs = scores.softmax(dim=-1)
+        log_probs = probs.log()
+
+        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
+        # in the set of softmax probabilities of the logits
+        neg_entropy = (probs * log_probs).sum(dim=-1, keepdim=True)
+
+        # Determine absolute difference between the negative entropy and the
+        # log probabilities
+        entropy_deviation = (neg_entropy - log_probs).abs()
+
+        # Keep certain tokens such that the sum of the entropy_deviation of the
+        # kept tokens is the smallest possible value such that the sum of the
+        # softmax probabilities of the kept tokens is at least the threshold
+        # value (by sorting the tokens in ascending order of entropy_deviation
+        # and then keeping the smallest possible number of tokens from the
+        # beginning such that sum of softmax probabilities is at or above the
+        # threshold)
+        _, sorted_indices = torch.sort(entropy_deviation)
+        sorted_logits = probs.gather(-1, sorted_indices)
+        sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= self.typical
+        sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
+
+        min_tokens_to_keep = max(self.min_tokens_to_keep, 1)
+        # Keep at least min_tokens_to_keep
+        sorted_indices_to_remove[..., : min_tokens_to_keep] = 0
+
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
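
For context, here is a minimal usage sketch of the warper added by this commit. It assumes warpers.py (and the LogitsWarper base class it subclasses, presumably imported there from transformers) is importable and that PyTorch is installed; the tensor values and the typical=0.2 threshold are invented for illustration. Note that the early-exit guard in __call__ compares self.filter_value, which defaults to -inf, so filtering always runs unless a filter_value of 1.0 or greater is passed in.

import torch

from warpers import TypicalLogitsWarper  # assumes this repository's warpers.py is on the path

# Two rows of made-up logits over a five-token vocabulary.
scores = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0],
                       [0.0, 0.0, 0.0, 5.0, -5.0]])
input_ids = torch.zeros((2, 1), dtype=torch.long)  # not used by this warper, but part of the interface

warper = TypicalLogitsWarper(typical=0.2, min_tokens_to_keep=1)
filtered = warper(input_ids, scores)

# Tokens outside the "typical" set are pushed to -inf, so a later softmax/sampling step ignores them.
print(filtered)

Because min_tokens_to_keep is clamped to at least 1 inside __call__, even a very small typical value leaves the single most typical token available for sampling.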
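To make the selection rule concrete, the following self-contained sketch (plain PyTorch, independent of warpers.py; the five-token distribution and the 0.6 threshold are invented) walks through the same steps on a single distribution: tokens are ranked by how far their log-probability sits from the distribution's negative entropy, and kept in that order until their cumulative probability reaches the threshold.

import torch

logits = torch.tensor([[3.0, 2.0, 1.0, 0.5, -2.0]])  # toy values
typical = 0.6  # hypothetical threshold

probs = logits.softmax(dim=-1)
log_probs = probs.log()

# Negative entropy of the distribution: sum of p * ln(p).
neg_entropy = (probs * log_probs).sum(dim=-1, keepdim=True)

# Distance of each token's log-probability from the negative entropy.
entropy_deviation = (neg_entropy - log_probs).abs()

# Sort by deviation and keep tokens until their probability mass reaches `typical`.
_, sorted_indices = torch.sort(entropy_deviation)
sorted_probs = probs.gather(-1, sorted_indices)
sorted_remove = sorted_probs.cumsum(dim=-1) >= typical
sorted_remove = sorted_remove.roll(1, dims=-1)  # keep the token that crosses the threshold
sorted_remove[..., :1] = False                  # always keep at least one token

remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
print("kept token ids:", (~remove).nonzero(as_tuple=True)[1].tolist())

Rolling the removal mask right by one position keeps the first token whose cumulative probability crosses the threshold, so the kept set's total mass is at least typical rather than falling just short of it.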