Typical sampling needs to use nansum instead of sum
If an entry of `probs` is zero, the corresponding entry of `log_probs` is negative infinity, and the calculation of `neg_entropy` then yields NaN because zero times infinity is mathematically indeterminate. We need to use nansum so that those NaN terms are treated as zeros and ignored in the entropy calculation.
commit 67e28d2b5c
parent 77ae893f4d
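As a quick illustration (not part of the commit), a minimal numpy sketch of the failure mode and the fix, with made-up probability values:

import numpy as np

# Hypothetical probability vector: one entry is exactly zero.
probs = np.array([0.7, 0.3, 0.0])
log_probs = np.log(probs)            # last entry is -inf (numpy warns about divide by zero)

# 0 * -inf is NaN, so a plain sum poisons the whole result...
print((probs * log_probs).sum())     # nan

# ...while nansum treats the NaN term as 0, giving the expected value.
print(np.nansum(probs * log_probs))  # approximately -0.6109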
@@ -255,7 +255,7 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
         log_probs = np.log(probs)
         # Compute the negative of entropy, which is the sum of p*ln(p) for all p
         # in the set of softmax probabilities of the logits
-        neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
+        neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)
         # Determine absolute difference between the negative entropy and the
         # log probabilities
         entropy_deviation = np.abs(neg_entropy - log_probs)
@@ -425,7 +425,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
         log_probs = jnp.log(probs)
         # Compute the negative of entropy, which is the sum of p*ln(p) for all p
         # in the set of softmax probabilities of the logits
-        neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
+        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
         # Determine absolute difference between the negative entropy and the
         # log probabilities
         entropy_deviation = jnp.abs(neg_entropy - log_probs)
@@ -123,7 +123,7 @@ class TypicalLogitsWarper(LogitsWarper):

         # Compute the negative of entropy, which is the sum of p*ln(p) for all p
         # in the set of softmax probabilities of the logits
-        neg_entropy = (probs * log_probs).sum(dim=-1, keepdim=True)
+        neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)

         # Determine absolute difference between the negative entropy and the
         # log probabilities
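For context, a rough standalone sketch of what the patched numpy-backend lines compute, assuming `logits` is the final array of token logits; the subsequent filtering that typical sampling applies to `entropy_deviation` is omitted here:

import numpy as np

def typical_entropy_deviation(logits):
    # Softmax probabilities and their logs (zero probabilities give -inf logs)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    log_probs = np.log(probs)
    # Negative entropy: sum of p*ln(p), with nansum so terms where p == 0 count as 0
    neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)
    # Absolute difference between the negative entropy and the log probabilities
    return np.abs(neg_entropy - log_probs)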