Merge pull request #107 from VE-FORBRYDERNE/typical

Typical sampling needs to use nansum instead of sum
This commit is contained in:
henk717 2022-03-28 11:13:38 +02:00 committed by GitHub
commit 8368b20421
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 3 deletions

View File

@ -255,7 +255,7 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
log_probs = np.log(probs)
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)
# Determine absolute difference between the negative entropy and the
# log probabilities
entropy_deviation = np.abs(neg_entropy - log_probs)
@ -425,7 +425,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
log_probs = jnp.log(probs)
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
# Determine absolute difference between the negative entropy and the
# log probabilities
entropy_deviation = jnp.abs(neg_entropy - log_probs)

View File

@ -123,7 +123,7 @@ class TypicalLogitsWarper(LogitsWarper):
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = (probs * log_probs).sum(dim=-1, keepdim=True)
neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)
# Determine absolute difference between the negative entropy and the
# log probabilities