From d5989d4c62c45e2f170118355543eb8928ac148e Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sun, 27 Mar 2022 16:57:12 -0400 Subject: [PATCH] Hide division by zero warning in JAX typical filter This warning happens when `np.log` gets an input containing zeros. In that case, NumPy will throw a warning and output negative infinity. Negative infinity is the correct behaviour here, so we can safely ignore the warning. --- tpu_mtj_backend.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 202e24dc..000f1713 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -247,11 +247,12 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty return np.where(indices_to_remove, -np.inf, logits) if tfs < 1.0: logits = tail_free_filter(logits) - # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf + # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) def typical_filter(logits): # Compute softmax probabilities and the natural logarithms of them probs = jax.nn.softmax(logits) - log_probs = np.log(probs) + with np.errstate(divide="ignore"): + log_probs = np.log(probs) # Compute the negative of entropy, which is the sum of p*ln(p) for all p # in the set of softmax probabilities of the logits neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True) @@ -417,7 +418,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ ) return jnp.where(indices_to_remove, -jnp.inf, logits) logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits) - # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf + # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) def typical_filter(logits): # Compute softmax probabilities and the natural logarithms of them probs = jax.nn.softmax(logits)