Top-A sampling
Commit to KoboldAI-Client (mirror of https://github.com/KoboldAI/KoboldAI-Client.git)
@@ -70,6 +70,7 @@ def settings_callback() -> dict:
         "top_k": 0,
         "tfs": 1.0,
         "typical": 1.0,
+        "top_a": 0.0,
         "repetition_penalty": 1.0,
         "rpslope": 0.0,
         "rprange": 0,
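The new top_a default of 0.0 keeps the filter inert: the removal threshold is a*m^2, where m is the peak softmax probability, so a = 0 gives a threshold of 0 and masks nothing (and, as the hunks below show, the samplers also skip the filter entirely unless top_a > 0). A quick standalone NumPy check of that no-op behaviour (illustrative, not part of the commit):

import numpy as np

logits = np.array([2.0, 1.0, 0.0])
probs = np.exp(logits) / np.exp(logits).sum()   # softmax
threshold = 0.0 * probs.max() ** 2              # top_a = 0.0 -> threshold is 0
assert not (probs < threshold).any()            # no probability falls below 0, so nothing is masked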
@@ -158,10 +159,10 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
     logits[tokens] = penalty_logits
     return logits
 
-def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
+def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     '''
-    This gets called by generate_loop_fn to apply a series of 5 filters
-    to the logits (top-k, then top-p, then TFS, then typical, then temperature)
+    This gets called by generate_loop_fn to apply a series of 6 filters
+    to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
     before picking one token using the modified logits
     '''
     # Top-k (keep only the k tokens with the highest logits and remove
@@ -182,6 +183,20 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
         return np.where(indices_to_remove, -np.inf, logits)
     if top_k > 0:
         logits = top_k_filter(logits)
+    # Top-a (remove all tokens that have softmax probability less than
+    # a*m^2 where m is the maximum softmax probability)
+    def top_a_filter(logits):
+        # Replace every element in the logits array
+        # with e (Euler's number) to the power of that element, and divide
+        # each element of the new array by the sum of the elements in the
+        # new array
+        probabilities = np.array(jax.nn.softmax(logits), copy=True)
+        # Find the largest probability
+        probs_max = probabilities.max()
+        # Remove tokens
+        return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits)
+    if top_a > 0.0:
+        logits = top_a_filter(logits)
     # Top-p (after sorting the remaining tokens again in descending order of
     # logit, remove the ones that have cumulative softmax probability
     # greater than p)
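A standalone sketch of the same top-a rule in plain NumPy (illustrative, not from the commit): because the cutoff a*m^2 scales with the square of the peak probability m, a sharply peaked distribution prunes aggressively while a flat one is left mostly alone.

import numpy as np

def top_a_mask(logits, top_a):
    # Numerically stable softmax: e**x / sum(e**x) after shifting by the max.
    shifted = logits - logits.max()
    probs = np.exp(shifted) / np.exp(shifted).sum()
    m = probs.max()                           # largest softmax probability
    return np.where(probs < top_a * m * m, -np.inf, logits)

peaked = np.array([5.0, 2.0, 1.0, 0.5])   # m ~ 0.93, threshold ~ 0.17 at top_a=0.2
flat = np.array([1.0, 0.9, 0.8, 0.7])     # m ~ 0.29, threshold ~ 0.017 at top_a=0.2
print(top_a_mask(peaked, 0.2))            # only the top token survives
print(top_a_mask(flat, 0.2))              # all four tokens survive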
@@ -332,10 +347,10 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
     # positions in the logits array
     return logits.at[tokens].set(penalty_logits)
 
-def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
+def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     '''
-    This gets called by generate_loop_fn to apply a series of 5 filters
-    to the logits (top-k, then top-p, then TFS, then typical, then temperature)
+    This gets called by generate_loop_fn to apply a series of 6 filters
+    to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
     before picking one token using the modified logits
     '''
     # Top-k (keep only the k tokens with the highest logits and remove
@@ -355,6 +370,19 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
         )
         return jnp.where(indices_to_remove, -jnp.inf, logits)
     logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
+    # Top-a (remove all tokens that have softmax probability less than
+    # a*m^2 where m is the maximum softmax probability)
+    def top_a_filter(logits):
+        # Replace every element in the logits array
+        # with e (Euler's number) to the power of that element, and divide
+        # each element of the new array by the sum of the elements in the
+        # new array
+        probabilities = jax.nn.softmax(logits)
+        # Find the largest probability
+        probs_max = probabilities.max()
+        # Remove tokens
+        return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits)
+    logits = jax.lax.cond(top_a > 0.0, top_a_filter, lambda x: x, logits)
     # Top-p (after sorting the remaining tokens again in descending order of
     # logit, remove the ones that have cumulative softmax probability
     # greater than p)
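Note the structural difference from the dynamic version: the static sampler runs under jax.jit, where a Python "if" on a traced scalar like top_a would fail at trace time, so jax.lax.cond stages both branches (the filter and an identity) into the compiled graph and selects between them at run time. A minimal standalone illustration of that pattern (assuming only jax; not the commit's code):

import jax
import jax.numpy as jnp

@jax.jit
def maybe_top_a(logits, top_a):
    def top_a_filter(l):
        probs = jax.nn.softmax(l)
        m = probs.max()
        return jnp.where(probs < top_a * m * m, -jnp.inf, l)
    # Both branches take and return the same shapes; the predicate is a traced boolean.
    return jax.lax.cond(top_a > 0.0, top_a_filter, lambda l: l, logits)

print(maybe_top_a(jnp.array([5.0, 2.0, 1.0, 0.5]), 0.2))   # filter branch taken
print(maybe_top_a(jnp.array([5.0, 2.0, 1.0, 0.5]), 0.0))   # identity branch taken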
@@ -806,6 +834,7 @@ def infer_static(
     top_k=0,
     tfs=1.0,
     typical=1.0,
+    top_a=0.0,
     repetition_penalty=1.0,
     rpslope=0.0,
     rprange=0,
@@ -829,6 +858,7 @@ def infer_static(
         "top_p": top_p * np.ones(total_batch),
         "tfs": tfs * np.ones(total_batch),
         "typical": typical * np.ones(total_batch),
+        "top_a": top_a * np.ones(total_batch),
         "repetition_penalty": repetition_penalty * np.ones(total_batch),
         "rpslope": rpslope * np.ones(total_batch),
         "rprange": np.full(total_batch, rprange, dtype=np.uint32),
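As with the other sampler settings, the scalar top_a is broadcast to one value per sequence in the batch, so each row could in principle carry its own setting. A small standalone sketch of that broadcasting pattern (names mirror the hunk above; illustrative only):

import numpy as np

total_batch = 3
top_a = 0.2

sampler_options = {
    "top_a": top_a * np.ones(total_batch),               # array([0.2, 0.2, 0.2])
    "rprange": np.full(total_batch, 0, dtype=np.uint32), # array([0, 0, 0], dtype=uint32)
}
print(sampler_options["top_a"])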