mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-06 04:14:18 +01:00
Fix jax.lax.cond
call
This commit is contained in:
parent
ff9058896e
commit
938e1eddf3
@ -498,7 +498,7 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), apply_repetition_penalty_static, lambda x, *_: x, logits, *rpargs)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs))
|
||||
# Finally, pick one token using the softmax thingy again (it gives
|
||||
# an array whose elements sum to 1 so it can be used nicely as a
|
||||
# probability distribution)
|
||||
|
Loading…
x
Reference in New Issue
Block a user