mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Repetition penalty is now sampler #6 in the sampler order
This commit is contained in:
@@ -312,10 +312,10 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
|
||||
if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
|
||||
if k == 4 and typical < 1.0: logits = typical_filter(logits)
|
||||
if k == 5 and temp != 1.0: logits = temp_filter(logits)
|
||||
if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs)
|
||||
# Finally, pick one token using the softmax thingy again (it gives
|
||||
# an array whose elements sum to 1 so it can be used nicely as a
|
||||
# probability distribution)
|
||||
logits = apply_repetition_penalty_dynamic(logits, *rpargs)
|
||||
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
||||
|
||||
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
|
||||
@@ -498,10 +498,10 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), apply_repetition_penalty_static, lambda x, *_: x, logits, *rpargs)
|
||||
# Finally, pick one token using the softmax thingy again (it gives
|
||||
# an array whose elements sum to 1 so it can be used nicely as a
|
||||
# probability distribution)
|
||||
logits = apply_repetition_penalty_static(logits, *rpargs)
|
||||
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
|
||||
|
||||
pad_token_id = 50256
|
||||
@@ -858,6 +858,9 @@ def infer_static(
|
||||
maps.thread_resources.env = thread_resources_env
|
||||
if sampler_order is None:
|
||||
sampler_order = utils.default_sampler_order.copy()
|
||||
sampler_order = sampler_order[:]
|
||||
if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present
|
||||
sampler_order = [6] + sampler_order
|
||||
sampler_order = np.uint32(sampler_order)
|
||||
total_batch = 1
|
||||
tokens = context
|
||||
|
Reference in New Issue
Block a user