From 67e28d2b5ce4e45d060516d1c7dffdc546aeafa3 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Mon, 28 Mar 2022 00:02:31 -0400
Subject: [PATCH] Typical sampling needs to use nansum instead of sum

If `probs` is zero then `log_probs` will be negative infinity, and the
calculation of `neg_entropy` would then give NaN because zero times
infinity is a mathematically indeterminate value. We need to use nansum
so that those NaN values are treated as zeros to ignore them in the
entropy calculation.
---
 tpu_mtj_backend.py | 4 ++--
 warpers.py         | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 000f1713..c7a8840f 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -255,7 +255,7 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
     log_probs = np.log(probs)
     # Compute the negative of entropy, which is the sum of p*ln(p) for all p
     # in the set of softmax probabilities of the logits
-    neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
+    neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)
     # Determine absolute difference between the negative entropy and the
     # log probabilities
     entropy_deviation = np.abs(neg_entropy - log_probs)
@@ -425,7 +425,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
     log_probs = jnp.log(probs)
     # Compute the negative of entropy, which is the sum of p*ln(p) for all p
     # in the set of softmax probabilities of the logits
-    neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
+    neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
     # Determine absolute difference between the negative entropy and the
     # log probabilities
     entropy_deviation = jnp.abs(neg_entropy - log_probs)
diff --git a/warpers.py b/warpers.py
index bb12cbb0..7c4f854b 100644
--- a/warpers.py
+++ b/warpers.py
@@ -123,7 +123,7 @@ class TypicalLogitsWarper(LogitsWarper):
        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
        # in the set of softmax probabilities of the logits
-       neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)
+       neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)
        # Determine absolute difference between the negative entropy and the
        # log probabilities