Undo pretty code because I haven't cracked the jax enigma yet

onesome
2023-04-25 19:54:58 -05:00
parent 1db9d9ba61
commit d496e861f4
2 changed files with 182 additions and 178 deletions


@@ -226,21 +226,185 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     # probability distribution)
     return jax.random.categorical(key, logits, -1).astype(np.uint32)
 
-def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
+def kobold_sample_static(
+    key,
+    logits,
+    rpargs,
+    sampler_order: Optional[np.ndarray] = None,
+    top_p=0.9,
+    temp=0.5,
+    top_k=0,
+    tfs=1.0,
+    typical=1.0,
+    top_a=0.0,
+):
     '''
     This gets called by generate_loop_fn to apply a series of 6 filters
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
     before picking one token using the modified logits
     '''
+    # Lame to have these here instead of modeling/warpers.py but JAX JIT stuff >:(
+    # For documentation see modeling/warpers.py
+    def sample_top_k(scores: jnp.array) -> jnp.array:
+        sorted_indices_to_remove = jnp.arange(len(scores)) >= top_k
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+    def sample_top_a(scores: jnp.array) -> jnp.array:
+        probabilities = jax.nn.softmax(scores)
+        probs_max = probabilities.max()
+        return jnp.where(
+            probabilities < probs_max * probs_max * top_a, -jnp.inf, scores
+        )
+    def sample_top_p(scores: jnp.array) -> jnp.array:
+        sorted_logits = -jnp.sort(-scores)
+        probabilities = jax.nn.softmax(sorted_logits)
+        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
+        sorted_indices_to_remove = cumulative_probabilities > top_p
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+    def sample_tail_free(scores: jnp.array) -> jnp.array:
+        sorted_logits = -jnp.sort(-scores)
+        probabilities = jax.nn.softmax(sorted_logits)
+        d2 = jnp.diff(jnp.diff(probabilities))
+        d2 = jnp.abs(d2)
+        d2 = d2 / d2.sum(axis=-1, keepdims=True)
+        cumulative_d2 = jnp.cumsum(d2, axis=-1)
+        sorted_indices_to_remove = cumulative_d2 > tfs
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        sorted_indices_to_remove = jnp.pad(
+            sorted_indices_to_remove,
+            (0, 2),
+            constant_values=True,
+        )
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+    def sample_typical(scores: jnp.array) -> jnp.array:
+        probs = jax.nn.softmax(scores)
+        log_probs = jnp.log(probs)
+        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
+        entropy_deviation = jnp.abs(neg_entropy - log_probs)
+        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
+        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
+        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(entropy_deviation),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+    def sample_temperature(scores: jnp.array) -> jnp.array:
+        return scores / temp
+    def sample_repetition_penalty(
+        logits: jnp.array,
+        tokens: jnp.array,
+        repetition_penalty,
+        generated_index,
+        rpslope,
+        rprange,
+    ) -> jnp.array:
+        """
+        This gets called to apply repetition penalty to the 1D array logits
+        using the provided 1D array of tokens to penalize
+        """
+        rpslope = jnp.int32(rpslope)
+        rprange = jnp.int32(rprange)
+        clipped_rprange = jax.lax.cond(
+            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
+        )
+        penalty_arange = jnp.roll(
+            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
+            generated_index,
+            axis=-1,
+        )
+        # Make a new array with the same length as the tokens array but with
+        # each element replaced by the value at the corresponding index in the
+        # logits array; e.g.
+        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
+        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
+        penalty_logits = jnp.take(logits, tokens)
+        # Repetition penalty slope
+        def apply_slope(carry):
+            repetition_penalty, rprange = carry
+            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
+            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
+            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
+            return _penalty
+        repetition_penalty = jax.lax.cond(
+            (rpslope != 0.0)
+            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
+            apply_slope,
+            lambda carry: jnp.full(tokens.shape, carry[0]),
+            (repetition_penalty, rprange),
+        )
+        # Divide positive values by repetition_penalty and multiply negative
+        # values by repetition_penalty (the academic publication that described
+        # this technique actually just only divided, but that would cause tokens
+        # with negative logits to become more likely, which is obviously wrong)
+        if koboldai_vars.use_alt_rep_pen:
+            penalty_logits = jnp.where(
+                penalty_arange >= 0,
+                penalty_logits - jnp.log(repetition_penalty),
+                penalty_logits,
+            )
+        else:
+            penalty_logits = jnp.where(
+                penalty_arange >= 0,
+                jnp.where(
+                    penalty_logits > 0,
+                    penalty_logits / repetition_penalty,
+                    penalty_logits * repetition_penalty,
+                ),
+                penalty_logits,
+            )
+        # Finally, put those penalized logit values back into their original
+        # positions in the logits array
+        return logits.at[tokens].set(penalty_logits)
     for k in sampler_order:
-        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), warpers.TopK.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), warpers.TopA.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), warpers.TopP.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), warpers.TailFree.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), warpers.Typical.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), warpers.Temperature.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: warpers.RepetitionPenalty.jax_static(*x), lambda x: x[0], (logits, *rpargs))
+        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), sample_top_k, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), sample_top_a, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), sample_top_p, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), sample_tail_free, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), sample_typical, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), sample_temperature, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: sample_repetition_penalty(*x), lambda x: x[0], (logits, *rpargs))
     return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
 
 pad_token_id = 50256
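
For readers unfamiliar with the scatter trick used throughout these filters: each sampler builds a removal mask in descending-score (rank) order, then uses jax.lax.sort_key_val to carry the mask back to the original token positions. A minimal, self-contained top-k sketch with toy numbers (not code from this commit):

    import jax
    import jax.numpy as jnp

    def top_k_filter(scores, top_k):
        # Mask every rank at or beyond top_k...
        sorted_indices_to_remove = jnp.arange(scores.shape[0]) >= top_k
        # ...then permute the mask from rank order back to token order:
        # argsort(-scores) lists token ids by descending score, and sorting
        # those ids back to 0..n-1 drags the mask along with them.
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    print(top_k_filter(jnp.array([2.0, 0.5, 1.0, 3.0]), top_k=2))
    # [  2. -inf -inf   3.]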
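sample_top_a keeps a token only if its probability is at least top_a times the square of the highest probability, so the cutoff adapts to how confident the model is. A toy check (values invented for illustration):

    import jax
    import jax.numpy as jnp

    def top_a_filter(scores, top_a):
        probabilities = jax.nn.softmax(scores)
        probs_max = probabilities.max()
        # The threshold scales with the peak probability squared
        return jnp.where(
            probabilities < probs_max * probs_max * top_a, -jnp.inf, scores
        )

    # softmax([5, 2, 0]) is roughly [0.95, 0.05, 0.01], so with top_a=0.5
    # the cutoff is about 0.45 and only the first logit survives.
    print(top_a_filter(jnp.array([5.0, 2.0, 0.0]), top_a=0.5))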
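Top-p (nucleus) filtering sorts probabilities in descending order and cuts the tail once the cumulative mass passes top_p; the .at[0].set(False) guard guarantees the single most likely token always survives. The same logic as sample_top_p above, runnable on its own:

    import jax
    import jax.numpy as jnp

    def top_p_filter(scores, top_p):
        sorted_logits = -jnp.sort(-scores)  # descending order
        probabilities = jax.nn.softmax(sorted_logits)
        cumulative = jnp.cumsum(probabilities, axis=-1)
        sorted_indices_to_remove = cumulative > top_p
        # Always keep at least the most probable token
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    # The top token alone carries ~0.83 of the mass and the next one
    # crosses 0.9, so only the 3.0 logit survives here.
    print(top_p_filter(jnp.array([3.0, 1.0, 0.0, -1.0]), top_p=0.9))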
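sample_tail_free differentiates the sorted probability curve twice and drops ranks once the normalized curvature is exhausted; each jnp.diff shortens the array by one, which is why the mask is padded with two True values before being scattered back. The core arithmetic on toy numbers:

    import jax.numpy as jnp

    probs = jnp.array([0.5, 0.25, 0.15, 0.06, 0.04])  # sorted, descending
    d2 = jnp.abs(jnp.diff(jnp.diff(probs)))  # |second difference|, length n - 2
    d2 = d2 / d2.sum(axis=-1, keepdims=True)  # normalize to a distribution
    print(jnp.cumsum(d2, axis=-1))  # ranks past the tfs threshold get dropped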
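sample_typical implements locally typical sampling: tokens are ranked by how far their surprisal (-log p) sits from the distribution's negative entropy, and only the most typical cumulative mass up to `typical` is kept. The ranking step in isolation:

    import jax
    import jax.numpy as jnp

    probs = jax.nn.softmax(jnp.array([3.0, 1.0, 0.5, 0.1]))
    log_probs = jnp.log(probs)
    # Negative entropy of the whole distribution
    neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
    # Small deviation = the token's surprisal is close to the entropy
    entropy_deviation = jnp.abs(neg_entropy - log_probs)
    print(jnp.argsort(entropy_deviation))  # most typical token ids first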
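The gather/penalize/scatter shape of sample_repetition_penalty is easiest to see with the numbers from its own comment; the real penalty math is replaced by a plain division here purely for illustration. Duplicate token ids all write identical values back in this toy, so the .at[tokens].set scatter is unambiguous:

    import jax.numpy as jnp

    logits = jnp.array([77.0, 5.0, 3.0, 98.0])
    tokens = jnp.array([0, 1, 2, 3, 2, 3, 1])

    penalty_logits = jnp.take(logits, tokens)  # [77. 5. 3. 98. 3. 98. 5.]
    penalty_logits = penalty_logits / 2.0      # stand-in for the real penalty
    # Scatter the penalized values back to their original positions
    print(logits.at[tokens].set(penalty_logits))  # [38.5  2.5  1.5 49. ]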
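The for-loop over sampler_order runs under jit with a traced k, so ordinary Python control flow cannot branch on it: each jax.lax.cond traces both branches and selects one at run time, and the predicates use jnp.logical_and (or &) because Python's `and` would try to collapse a tracer to a concrete bool and raise, which is the crash the slope comment warns about. The pattern in miniature, with a made-up halving step standing in for a warper:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def dispatch(k, x):
        # Both branches are traced; the predicate picks one at run time.
        return jax.lax.cond(
            jnp.logical_and(k == 0, x.max() > 0),
            lambda v: v / 2.0,  # hypothetical sampler step
            lambda v: v,        # identity when the step is skipped
            x,
        )

    print(dispatch(0, jnp.array([2.0, -4.0])))  # [ 1. -2.]
    print(dispatch(1, jnp.array([2.0, -4.0])))  # [ 2. -4.]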
@@ -357,11 +521,10 @@ class PenalizingCausalTransformer(CausalTransformer):
                 logits,
                 (
                     generated,
-                    # repetition_penalty,
+                    repetition_penalty,
                     generated_index,
-                    # gen_length,
-                    # rpslope,
-                    # rprange,
+                    rpslope,
+                    rprange,
                 ),
                 **sampler_options,
             )
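
This second hunk un-comments the repetition-penalty arguments that had been stubbed out while the warpers refactor was in flight (gen_length stays commented out on the old side and is simply dropped). The tuple is the rpargs parameter of kobold_sample_static, which the k == 6 branch above unpacks into sample_repetition_penalty. A sketch of the tuple's expected layout; the values are toy stand-ins and only the field order comes from the diff:

    import jax.numpy as jnp

    # Hypothetical stand-ins; in the backend these come from the loop state.
    generated = jnp.zeros(8, dtype=jnp.uint32)  # tokens generated so far
    repetition_penalty = 1.1                    # 1.0 would disable the branch
    generated_index = 0                         # write position in `generated`
    rpslope = 0.0                               # slope of the penalty ramp
    rprange = 0                                 # <= 0 penalizes the whole buffer

    # The k == 6 cond tests rpargs[1], i.e. repetition_penalty, against 1.0
    rpargs = (generated, repetition_penalty, generated_index, rpslope, rprange)
    assert rpargs[1] != 1.0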