Repetition penalty is now sampler #6 in the sampler order

This commit is contained in:
vfbd
2022-08-23 15:10:21 -04:00
parent 9eecb61fea
commit 6ffaf43548
3 changed files with 21 additions and 8 deletions

View File

@@ -312,10 +312,10 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
if k == 4 and typical < 1.0: logits = typical_filter(logits)
if k == 5 and temp != 1.0: logits = temp_filter(logits)
if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs)
# Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a
# probability distribution)
logits = apply_repetition_penalty_dynamic(logits, *rpargs)
return jax.random.categorical(key, logits, -1).astype(np.uint32)
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
@@ -498,10 +498,10 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), apply_repetition_penalty_static, lambda x, *_: x, logits, *rpargs)
# Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a
# probability distribution)
logits = apply_repetition_penalty_static(logits, *rpargs)
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
pad_token_id = 50256
@@ -858,6 +858,9 @@ def infer_static(
maps.thread_resources.env = thread_resources_env
if sampler_order is None:
sampler_order = utils.default_sampler_order.copy()
sampler_order = sampler_order[:]
if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present
sampler_order = [6] + sampler_order
sampler_order = np.uint32(sampler_order)
total_batch = 1
tokens = context