Hide division by zero warning in JAX typical filter
This warning happens when `np.log` gets an input containing zeros. In that case, NumPy will throw a warning and output negative infinity. Negative infinity is the correct behaviour here, so we can safely ignore the warning.
This commit is contained in:
parent
20e48b11d7
commit
d5989d4c62
|
@ -247,11 +247,12 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
|
|||
return np.where(indices_to_remove, -np.inf, logits)
|
||||
if tfs < 1.0:
|
||||
logits = tail_free_filter(logits)
|
||||
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf
|
||||
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
|
||||
def typical_filter(logits):
|
||||
# Compute softmax probabilities and the natural logarithms of them
|
||||
probs = jax.nn.softmax(logits)
|
||||
log_probs = np.log(probs)
|
||||
with np.errstate(divide="ignore"):
|
||||
log_probs = np.log(probs)
|
||||
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
|
||||
# in the set of softmax probabilities of the logits
|
||||
neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
|
||||
|
@ -417,7 +418,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
|
|||
)
|
||||
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
||||
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
|
||||
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf
|
||||
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
|
||||
def typical_filter(logits):
|
||||
# Compute softmax probabilities and the natural logarithms of them
|
||||
probs = jax.nn.softmax(logits)
|
||||
|
|
Loading…
Reference in New Issue