Undo pretty code because I haven't cracked the jax enigma yet

onesome
2023-04-25 19:54:58 -05:00
parent 1db9d9ba61
commit d496e861f4
2 changed files with 182 additions and 178 deletions

modeling/warpers.py

@@ -68,13 +68,16 @@ class Warper:
    All Warpers should be singletons defined in the warpers.py file.

    To make a new warper/sampler:
    - Create your class, implementing `torch()`, `jax_dynamic`, `jax_static`,
      and `value_is_valid()`. Dynamic and static methods are separated for Jax
    - Create your class, implementing `torch()`, `jax_dynamic`, and
      `value_is_valid()`. Dynamic and static methods are separated for Jax
      due to how it does JIT compilation of functions (from what I gather).
      These `static` methods are very picky about what you can and can't do
      These static methods are very picky about what you can and can't do
      with data at runtime and thus sometimes need to be implemented
      differently than the `dynamic` methods, which are more like the Torch
      methods.
    - Implement your TPU static function (if applicable) in
      tpu_mtj_backend.py->kobold_sample_static. If you do not do this, your
      sampler will only work on the much slower dynamic generation mode on TPU.
    - Add it to Warper.from_id and tpu_mtj_backend.kobold_sample_static.
    - Add it to the UI/sampler_order.
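For readers skimming the diff, here is a minimal sketch of what the docstring above asks for, modeled on the Temperature warper further down. `MyScale` and `my_value` are invented names for illustration; per this commit only `torch()`, `jax_dynamic()`, and `value_is_valid()` live on the class, with the TPU static half going into tpu_mtj_backend.kobold_sample_static:

import numpy as np
import torch

class MyScale(Warper):  # hypothetical example, not part of the codebase
    """Divides every logit by a constant; a near-clone of Temperature."""

    my_value: float = 1.0  # invented attribute for illustration

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        return scores / cls.my_value

    @classmethod
    def jax_dynamic(cls, scores: np.array) -> np.array:
        return scores / cls.my_value

    @classmethod
    def value_is_valid(cls) -> bool:
        # If the value is a no-op, the warper is skipped entirely
        return cls.my_value != 1.0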
@@ -103,10 +106,6 @@ class Warper:
    def jax_dynamic(cls, scores: np.array) -> np.array:
        raise NotImplementedError("Please override `jax_dynamic()`.")

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        raise NotImplementedError("Please override `jax_static()`.")

    @classmethod
    def value_is_valid(cls) -> bool:
        raise NotImplementedError("Please override `value_is_valid()`.")
@@ -125,10 +124,6 @@ class Temperature(Warper):
    def jax_dynamic(cls, scores: np.array) -> np.array:
        return scores / cls.temperature

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        return scores / cls.temperature

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.temperature != 1.0
@@ -181,30 +176,6 @@ class TopP(Warper):
        )
        return np.where(indices_to_remove, -np.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        # Sort the logits array in descending order, replace every element
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        sorted_logits = -jnp.sort(-scores)
        probabilities = jax.nn.softmax(sorted_logits)
        # Calculate cumulative_probabilities as the prefix-sum array of
        # probabilities
        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
        # We want to remove tokens with cumulative probability higher
        # than top_p
        sorted_indices_to_remove = cumulative_probabilities > cls.top_p
        # Don't ever remove the token with the highest logit, even if
        # the probability is higher than top_p
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.top_p < 1.0
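For intuition, a small worked example of the top-p filter above, computed with ordinary NumPy on scores that are already sorted descending (so the unsorting step can be skipped); the numbers are illustrative only:

import numpy as np

scores = np.array([3.0, 1.0, 0.0, -1.0])          # already sorted, descending
probs = np.exp(scores) / np.exp(scores).sum()     # softmax: ~[0.83, 0.11, 0.04, 0.02]
cumulative = np.cumsum(probs)                     # ~[0.83, 0.94, 0.98, 1.00]
# Every token whose cumulative probability exceeds top_p = 0.9 is masked,
# so here only the first token survives:
mask = cumulative > 0.9                           # [False, True, True, True]
mask[0] = False                                   # never remove the top token
print(np.where(mask, -np.inf, scores))            # [3.0, -inf, -inf, -inf]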
@@ -242,16 +213,6 @@ class TopK(Warper):
        )
        return np.where(indices_to_remove, -np.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        sorted_indices_to_remove = jnp.arange(len(scores)) >= cls.top_k
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.top_k > 0
@@ -339,30 +300,6 @@ class TailFree(Warper):
        )
        return np.where(indices_to_remove, -np.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        sorted_logits = -jnp.sort(-scores)
        probabilities = jax.nn.softmax(sorted_logits)
        d2 = jnp.diff(jnp.diff(probabilities))
        d2 = jnp.abs(d2)
        d2 = d2 / d2.sum(axis=-1, keepdims=True)
        cumulative_d2 = jnp.cumsum(d2, axis=-1)
        sorted_indices_to_remove = cumulative_d2 > cls.tfs
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        sorted_indices_to_remove = jnp.pad(
            sorted_indices_to_remove,
            (0, 2),
            constant_values=True,
        )
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.tfs < 1.0
@@ -443,25 +380,6 @@ class Typical(Warper):
        )
        return np.where(indices_to_remove, -np.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        probs = jax.nn.softmax(scores)
        log_probs = jnp.log(probs)
        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
        entropy_deviation = jnp.abs(neg_entropy - log_probs)
        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= cls.typical
        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(entropy_deviation),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.typical < 1.0
@@ -504,14 +422,6 @@ class TopA(Warper):
            probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
        )

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        probabilities = jax.nn.softmax(scores)
        probs_max = probabilities.max()
        return jnp.where(
            probabilities < probs_max * probs_max * cls.top_a, -jnp.inf, scores
        )

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.top_a > 0.0
@@ -557,75 +467,6 @@ class RepetitionPenalty(Warper):
        return scores

    @classmethod
    def jax_static(
        cls,
        logits: jnp.array,
        tokens: jnp.array,
        generated_index,
    ) -> jnp.array:
        """
        This gets called to apply repetition penalty to the 1D array logits
        using the provided 1D array of tokens to penalize
        """
        rpslope = jnp.int32(cls.rep_pen_slope)
        rprange = jnp.int32(cls.rep_pen_range)
        repetition_penalty = cls.rep_pen
        clipped_rprange = jax.lax.cond(
            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
        )
        penalty_arange = jnp.roll(
            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
            generated_index,
            axis=-1,
        )
        # Make a new array with the same length as the tokens array but with
        # each element replaced by the value at the corresponding index in the
        # logits array; e.g.
        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
        penalty_logits = jnp.take(logits, tokens)

        # Repetition penalty slope
        def apply_slope(carry):
            repetition_penalty, rprange = carry
            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
            return _penalty

        repetition_penalty = jax.lax.cond(
            (rpslope != 0.0)
            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
            apply_slope,
            lambda carry: jnp.full(tokens.shape, carry[0]),
            (repetition_penalty, rprange),
        )
        # Divide positive values by repetition_penalty and multiply negative
        # values by repetition_penalty (the academic publication that described
        # this technique actually just divided, but that would cause tokens
        # with negative logits to become more likely, which is obviously wrong)
        if cls.use_alt_rep_pen:
            penalty_logits = jnp.where(
                penalty_arange >= 0,
                penalty_logits - jnp.log(repetition_penalty),
                penalty_logits,
            )
        else:
            penalty_logits = jnp.where(
                penalty_arange >= 0,
                jnp.where(
                    penalty_logits > 0,
                    penalty_logits / repetition_penalty,
                    penalty_logits * repetition_penalty,
                ),
                penalty_logits,
            )
        # Finally, put those penalized logit values back into their original
        # positions in the logits array
        return logits.at[tokens].set(penalty_logits)

    @classmethod
    def jax_dynamic(
        cls,
tpu_mtj_backend.py

@@ -226,21 +226,185 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
    # probability distribution)
    return jax.random.categorical(key, logits, -1).astype(np.uint32)
def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
def kobold_sample_static(
    key,
    logits,
    rpargs,
    sampler_order: Optional[np.ndarray] = None,
    top_p=0.9,
    temp=0.5,
    top_k=0,
    tfs=1.0,
    typical=1.0,
    top_a=0.0,
):
    '''
    This gets called by generate_loop_fn to apply a series of 6 filters
    to the logits (top-k, then top-a, then top-p, then TFS, then typical,
    then temperature) before picking one token using the modified logits
    '''
    # Lame to have these here instead of modeling/warpers.py but JAX JIT stuff >:(
    # For documentation see modeling/warpers.py
    def sample_top_k(scores: jnp.array) -> jnp.array:
        sorted_indices_to_remove = jnp.arange(len(scores)) >= top_k
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    def sample_top_a(scores: jnp.array) -> jnp.array:
        probabilities = jax.nn.softmax(scores)
        probs_max = probabilities.max()
        return jnp.where(
            probabilities < probs_max * probs_max * top_a, -jnp.inf, scores
        )

    def sample_top_p(scores: jnp.array) -> jnp.array:
        sorted_logits = -jnp.sort(-scores)
        probabilities = jax.nn.softmax(sorted_logits)
        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
        sorted_indices_to_remove = cumulative_probabilities > top_p
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    def sample_tail_free(scores: jnp.array) -> jnp.array:
        sorted_logits = -jnp.sort(-scores)
        probabilities = jax.nn.softmax(sorted_logits)
        d2 = jnp.diff(jnp.diff(probabilities))
        d2 = jnp.abs(d2)
        d2 = d2 / d2.sum(axis=-1, keepdims=True)
        cumulative_d2 = jnp.cumsum(d2, axis=-1)
        sorted_indices_to_remove = cumulative_d2 > tfs
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
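        # Applying jnp.diff twice shrank the mask by 2 elements relative to
        # sorted_logits, so pad it back out with two always-removed entries: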
        sorted_indices_to_remove = jnp.pad(
            sorted_indices_to_remove,
            (0, 2),
            constant_values=True,
        )
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    def sample_typical(scores: jnp.array) -> jnp.array:
        probs = jax.nn.softmax(scores)
        log_probs = jnp.log(probs)
        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
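        # (nansum, because probabilities that underflow to 0 give
        # 0 * log(0) = nan, which should count as 0 here)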
        entropy_deviation = jnp.abs(neg_entropy - log_probs)
        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(entropy_deviation),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    def sample_temperature(scores: jnp.array) -> jnp.array:
        return scores / temp
    def sample_repetition_penalty(
        logits: jnp.array,
        tokens: jnp.array,
        repetition_penalty,
        generated_index,
        rpslope,
        rprange,
    ) -> jnp.array:
        """
        This gets called to apply repetition penalty to the 1D array logits
        using the provided 1D array of tokens to penalize
        """
        rpslope = jnp.int32(rpslope)
        rprange = jnp.int32(rprange)
        clipped_rprange = jax.lax.cond(
            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
        )
        penalty_arange = jnp.roll(
            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
            generated_index,
            axis=-1,
        )
        # Make a new array with the same length as the tokens array but with
        # each element replaced by the value at the corresponding index in the
        # logits array; e.g.
        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
        penalty_logits = jnp.take(logits, tokens)

        # Repetition penalty slope
        def apply_slope(carry):
            repetition_penalty, rprange = carry
            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
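            # (maps positions 0..rprange-1 onto [-1, 1]; after the two lines
            # below, older tokens end up with a penalty closer to 1, i.e. weaker)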
            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
            return _penalty

        repetition_penalty = jax.lax.cond(
            (rpslope != 0.0)
            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
            apply_slope,
            lambda carry: jnp.full(tokens.shape, carry[0]),
            (repetition_penalty, rprange),
        )
        # Divide positive values by repetition_penalty and multiply negative
        # values by repetition_penalty (the academic publication that described
        # this technique actually just divided, but that would cause tokens
        # with negative logits to become more likely, which is obviously wrong)
        if koboldai_vars.use_alt_rep_pen:
            penalty_logits = jnp.where(
                penalty_arange >= 0,
                penalty_logits - jnp.log(repetition_penalty),
                penalty_logits,
            )
        else:
            penalty_logits = jnp.where(
                penalty_arange >= 0,
                jnp.where(
                    penalty_logits > 0,
                    penalty_logits / repetition_penalty,
                    penalty_logits * repetition_penalty,
                ),
                penalty_logits,
            )
        # Finally, put those penalized logit values back into their original
        # positions in the logits array
        return logits.at[tokens].set(penalty_logits)
    for k in sampler_order:
        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), warpers.TopK.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), warpers.TopA.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), warpers.TopP.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), warpers.TailFree.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), warpers.Typical.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), warpers.Temperature.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: warpers.RepetitionPenalty.jax_static(*x), lambda x: x[0], (logits, *rpargs))
        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), sample_top_k, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), sample_top_a, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), sample_top_p, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), sample_tail_free, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), sample_typical, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), sample_temperature, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: sample_repetition_penalty(*x), lambda x: x[0], (logits, *rpargs))
    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
pad_token_id = 50256
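The loop above is the reason these samplers live here at all: under `jax.jit`, a Python `if` on a traced value is not allowed, so each sampler in `sampler_order` is dispatched through `jax.lax.cond`, whose two branches must accept and return values of the same shape and dtype. A minimal standalone sketch of that pattern, independent of KoboldAI's code (the function and values are invented):

import jax
import jax.numpy as jnp

@jax.jit
def maybe_scale(logits: jnp.ndarray, k: jnp.ndarray) -> jnp.ndarray:
    # Both branches take the same operand and return an array of the same
    # shape/dtype, so tracing succeeds even though k is only known at runtime.
    return jax.lax.cond(
        k == 5,
        lambda x: x / 0.5,   # "sampler enabled" branch
        lambda x: x,         # identity branch
        logits,
    )

print(maybe_scale(jnp.array([1.0, 2.0]), jnp.int32(5)))  # [2.0, 4.0]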
@@ -357,11 +521,10 @@ class PenalizingCausalTransformer(CausalTransformer):
        logits,
        (
            generated,
            # repetition_penalty,
            repetition_penalty,
            generated_index,
            # gen_length,
            # rpslope,
            # rprange,
            rpslope,
            rprange,
        ),
        **sampler_options,
    )
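For context, this last hunk re-enables the repetition penalty arguments that were previously commented out, so they now travel to kobold_sample_static's k == 6 branch as the rpargs tuple. A sketch of the call shape this implies; the variable names come from the hunk, while `key` and `sampler_options` are assumed from kobold_sample_static's signature rather than visible here:

# Sketch only: rpargs is unpacked as (tokens, repetition_penalty,
# generated_index, rpslope, rprange) inside the k == 6 branch.
rpargs = (
    generated,            # 1D array of already-generated token ids
    repetition_penalty,   # scalar; the branch is skipped when this is 1.0
    generated_index,
    rpslope,
    rprange,
)
next_token = kobold_sample_static(key, logits, rpargs, **sampler_options)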