diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 19296e0a..effb3de0 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -498,7 +498,7 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits) - logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), apply_repetition_penalty_static, lambda x, *_: x, logits, *rpargs) + logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs)) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution)