Hide division by zero warning in JAX typical filter

This warning happens when `np.log` gets an input containing zeros.
In that case, NumPy will throw a warning and output negative infinity.

Negative infinity is the correct behaviour here, so we can safely ignore
the warning.
This commit is contained in:
Gnome Ann 2022-03-27 16:57:12 -04:00
parent 20e48b11d7
commit d5989d4c62
1 changed files with 4 additions and 3 deletions

View File

@ -247,11 +247,12 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
return np.where(indices_to_remove, -np.inf, logits)
if tfs < 1.0:
logits = tail_free_filter(logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them
probs = jax.nn.softmax(logits)
log_probs = np.log(probs)
with np.errstate(divide="ignore"):
log_probs = np.log(probs)
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
@ -417,7 +418,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them
probs = jax.nn.softmax(logits)