Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Typical sampling
@@ -67,6 +67,7 @@ def settings_callback() -> dict:
         "temp": 0.5,
         "top_k": 0,
         "tfs": 1.0,
+        "typical": 1.0,
         "repetition_penalty": 1.0,
         "rpslope": 0.0,
         "rprange": 0,
@@ -155,11 +156,11 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
     logits[tokens] = penalty_logits
     return logits
 
-def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
+def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
     '''
-    This gets called by generate_loop_fn to apply a series of 4 filters
-    to the logits (top-k, then top-p, then TFS, then temperature) before
-    picking one token using the modified logits
+    This gets called by generate_loop_fn to apply a series of 5 filters
+    to the logits (top-k, then top-p, then TFS, then typical, then temperature)
+    before picking one token using the modified logits
     '''
     # Top-k (keep only the k tokens with the highest logits and remove
     # the rest, by setting their logits to negative infinity)
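Note: the new typical argument defaults to 1.0, which leaves the filter disabled. A minimal usage sketch of the updated signature, assuming this module (tpu_mtj_backend.py in the KoboldAI repo) is importable and initialised; the key, logits, and threshold values below are made up for illustration:

    import jax
    import numpy as np
    import tpu_mtj_backend

    key = jax.random.PRNGKey(0)                      # example PRNG key
    logits = np.random.randn(8).astype(np.float32)   # fake 8-token vocabulary

    # Any value below 1.0 enables typical sampling; 1.0 keeps it off.
    token = tpu_mtj_backend.kobold_sample_dynamic(
        key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=0.2
    )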
@@ -246,6 +247,36 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
         return np.where(indices_to_remove, -np.inf, logits)
     if tfs < 1.0:
         logits = tail_free_filter(logits)
+    # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
+    def typical_filter(logits):
+        # Compute softmax probabilities and the natural logarithms of them
+        probs = jax.nn.softmax(logits)
+        log_probs = np.log(probs)
+        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
+        # in the set of softmax probabilities of the logits
+        neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
+        # Determine absolute difference between the negative entropy and the
+        # log probabilities
+        entropy_deviation = np.abs(neg_entropy - log_probs)
+        # Keep certain tokens such that the sum of the entropy_deviation of the
+        # kept tokens is the smallest possible value such that the sum of the
+        # softmax probabilities of the kept tokens is at least the threshold
+        # value (by sorting the tokens in ascending order of entropy_deviation
+        # and then keeping the smallest possible number of tokens from the
+        # beginning such that sum of softmax probabilities is at or above the
+        # threshold)
+        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
+        sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= typical
+        sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
+        sorted_indices_to_remove[0] = False
+        # Unsort and remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(entropy_deviation),
+            sorted_indices_to_remove,
+        )
+        return np.where(indices_to_remove, -jnp.inf, logits)
+    if typical < 1.0:
+        logits = typical_filter(logits)
     # Temperature (just divide the logits by the temperature)
     logits /= temp
     # Finally, pick one token using the softmax thingy again (it gives
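For readers who want to check the math outside of JAX, here is a minimal pure-NumPy sketch of the same filtering rule (an illustration only, not the code added above; the function name and threshold value are made up):

    import numpy as np

    def typical_filter_numpy(logits, typical=0.25):
        # Softmax probabilities and their natural logarithms
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        log_probs = np.log(probs)
        # Negative entropy: the sum of p*ln(p) over the distribution
        neg_entropy = (probs * log_probs).sum()
        # Distance of each token's log probability from the negative entropy
        entropy_deviation = np.abs(neg_entropy - log_probs)
        # Keep the most "typical" tokens (smallest deviation) until their
        # cumulative probability reaches the threshold; always keep at least one
        order = np.argsort(entropy_deviation)
        cumulative = np.cumsum(probs[order])
        keep_count = int(np.searchsorted(cumulative, typical)) + 1
        keep = np.zeros(logits.shape, dtype=bool)
        keep[order[:keep_count]] = True
        return np.where(keep, logits, -np.inf)

    # Toy example: with a peaked 5-token distribution, only the most
    # "typical" tokens keep a finite logit.
    print(typical_filter_numpy(np.array([2.0, 1.0, 0.5, 0.1, -1.0])))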
@@ -298,11 +329,11 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
     # positions in the logits array
     return logits.at[tokens].set(penalty_logits)
 
-def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
+def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
     '''
-    This gets called by generate_loop_fn to apply a series of 4 filters
-    to the logits (top-k, then top-p, then TFS, then temperature) before
-    picking one token using the modified logits
+    This gets called by generate_loop_fn to apply a series of 5 filters
+    to the logits (top-k, then top-p, then TFS, then typical, then temperature)
+    before picking one token using the modified logits
     '''
     # Top-k (keep only the k tokens with the highest logits and remove
     # the rest, by setting their logits to negative infinity)
@@ -386,6 +417,35 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
         )
         return jnp.where(indices_to_remove, -jnp.inf, logits)
     logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
+    # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
+    def typical_filter(logits):
+        # Compute softmax probabilities and the natural logarithms of them
+        probs = jax.nn.softmax(logits)
+        log_probs = jnp.log(probs)
+        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
+        # in the set of softmax probabilities of the logits
+        neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
+        # Determine absolute difference between the negative entropy and the
+        # log probabilities
+        entropy_deviation = jnp.abs(neg_entropy - log_probs)
+        # Keep certain tokens such that the sum of the entropy_deviation of the
+        # kept tokens is the smallest possible value such that the sum of the
+        # softmax probabilities of the kept tokens is at least the threshold
+        # value (by sorting the tokens in ascending order of entropy_deviation
+        # and then keeping the smallest possible number of tokens from the
+        # beginning such that sum of softmax probabilities is at or above the
+        # threshold)
+        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
+        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
+        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        # Unsort and remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(entropy_deviation),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, logits)
+    logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits)
     # Temperature (just divide the logits by the temperature)
     def temp_filter(logits):
         return logits / temp
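Unlike the dynamic version, the static sampler runs under jax.jit, so a plain Python if on the typical threshold is not possible once the value is traced; jax.lax.cond traces both branches and picks between the filter and an identity function at run time instead. A small standalone sketch of that pattern (names here are illustrative, not taken from the repo):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def maybe_filter(logits, typical):
        # Both branches must take and return arrays of the same shape/dtype.
        def apply_filter(x):
            return x - 1.0  # stand-in for the real typical_filter
        return jax.lax.cond(typical < 1.0, apply_filter, lambda x: x, logits)

    print(maybe_filter(jnp.array([1.0, 2.0]), 0.5))  # filter branch taken
    print(maybe_filter(jnp.array([1.0, 2.0]), 1.0))  # identity branch taken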
@@ -742,6 +802,7 @@ def infer_static(
     temp=0.5,
     top_k=0,
     tfs=1.0,
+    typical=1.0,
     repetition_penalty=1.0,
     rpslope=0.0,
     rprange=0,
@@ -764,6 +825,7 @@ def infer_static(
         "temp": temp * np.ones(total_batch),
         "top_p": top_p * np.ones(total_batch),
         "tfs": tfs * np.ones(total_batch),
+        "typical": typical * np.ones(total_batch),
         "repetition_penalty": repetition_penalty * np.ones(total_batch),
         "rpslope": rpslope * np.ones(total_batch),
         "rprange": np.full(total_batch, rprange, dtype=np.uint32),
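Each scalar setting is broadcast to one value per batch element before being handed to the sampler. For illustration, assuming a total_batch of 2 and a typical value of 0.2 (both made up), the new entry expands like this:

    import numpy as np

    total_batch = 2   # illustrative
    typical = 0.2     # illustrative

    print(typical * np.ones(total_batch))            # -> [0.2 0.2]
    print(np.full(total_batch, 0, dtype=np.uint32))  # rprange pattern -> [0 0]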