Typical sampling needs to use nansum instead of sum

If an entry of `probs` is zero, the corresponding entry of `log_probs`
is negative infinity, and the `neg_entropy` calculation then produces
NaN, because zero times infinity is mathematically indeterminate (IEEE
754 defines `0 * inf` as NaN).

We need to use nansum so that those NaN terms are treated as zeros,
matching the convention that 0 * ln(0) contributes nothing to the
entropy.
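The failure mode is easy to reproduce outside the warper. The sketch below uses NumPy rather than PyTorch (NumPy's `nansum` behaves like `torch.Tensor.nansum` for this purpose); the probability values are illustrative, not taken from the patch:

```python
import numpy as np

# A softmax distribution in which one outcome has probability zero.
probs = np.array([0.5, 0.5, 0.0])

# log(0) -> -inf; suppress the expected floating-point warnings.
with np.errstate(divide="ignore", invalid="ignore"):
    log_probs = np.log(probs)

products = probs * log_probs  # 0.0 * -inf -> NaN under IEEE 754

# A plain sum propagates the NaN and poisons the entropy value.
print(np.sum(products))      # nan

# nansum skips NaN terms, treating them as zero contributions,
# which recovers the correct negative entropy of the distribution.
print(np.nansum(products))   # -0.6931... == -ln(2)
```

With `sum`, a single zero-probability token makes the whole entropy NaN, which then corrupts every downstream comparison in the warper; `nansum` confines the damage to the terms that should contribute nothing anyway.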
Author: Gnome Ann
Date:   2022-03-28 00:02:31 -04:00
parent 77ae893f4d
commit 67e28d2b5c
2 changed files with 3 additions and 3 deletions

@@ -123,7 +123,7 @@ class TypicalLogitsWarper(LogitsWarper):
         # Compute the negative of entropy, which is the sum of p*ln(p) for all p
         # in the set of softmax probabilities of the logits
-        neg_entropy = (probs * log_probs).sum(dim=-1, keepdim=True)
+        neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)
         # Determine absolute difference between the negative entropy and the
         # log probabilities