mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Undo pretty code because I haven't cracked the jax enigma yet
This commit is contained in:
@@ -68,13 +68,16 @@ class Warper:
|
||||
All Warpers should be singletons defined in the warpers.py file.
|
||||
|
||||
To make a new warper/sampler:
|
||||
- Create your class, implementing `torch()`, `jax_dynamic`, `jax_static`,
|
||||
and `value_is_valid()`. Dynamic and static methods are seperated for Jax
|
||||
- Create your class, implementing `torch()`, `jax_dynamic`, and
|
||||
`value_is_valid()`. Dynamic and static methods are seperated for Jax
|
||||
due to how it does JIT compilation of functions (from what I gather).
|
||||
These `static` methods are very picky about what you can and can't do
|
||||
These static methods are very picky about what you can and can't do
|
||||
with data at runtime and thus sometimes need to be implemented
|
||||
differently than the `dynamic` methods, which are more like the Torch
|
||||
methods.
|
||||
- Implement your TPU static function (if applicable) in
|
||||
tpu_mtj_backend.py->kobold_sample_static. If you do not do this, your
|
||||
sampler will only work on the much slower dynamic generation mode on TPU.
|
||||
- Add it to Warper.from_id and tpu_mtj_backend.kobold_sample_static.
|
||||
- Add it to the UI/sampler_order.
|
||||
|
||||
@@ -103,10 +106,6 @@ class Warper:
|
||||
def jax_dynamic(cls, scores: np.array) -> np.array:
|
||||
raise NotImplementedError("Please override `jax_dynamic()`.")
|
||||
|
||||
@classmethod
|
||||
def jax_static(cls, scores: jnp.array) -> jnp.array:
|
||||
raise NotImplementedError("Please override `jax_static()`.")
|
||||
|
||||
@classmethod
|
||||
def value_is_valid(cls) -> bool:
|
||||
raise NotImplementedError("Please override `value_is_valid()`.")
|
||||
@@ -125,10 +124,6 @@ class Temperature(Warper):
|
||||
def jax_dynamic(cls, scores: np.array) -> np.array:
|
||||
return scores / cls.temperature
|
||||
|
||||
@classmethod
|
||||
def jax_static(cls, scores: jnp.array) -> jnp.array:
|
||||
return scores / cls.temperature
|
||||
|
||||
@classmethod
|
||||
def value_is_valid(cls) -> bool:
|
||||
return cls.temperature != 1.0
|
||||
@@ -181,30 +176,6 @@ class TopP(Warper):
|
||||
)
|
||||
return np.where(indices_to_remove, -np.inf, scores)
|
||||
|
||||
@classmethod
|
||||
def jax_static(cls, scores: jnp.array) -> jnp.array:
|
||||
# Sort the logits array in descending order, replace every element
|
||||
# with e (Euler's number) to the power of that element, and divide
|
||||
# each element of the new array by the sum of the elements in the
|
||||
# new array
|
||||
sorted_logits = -jnp.sort(-scores)
|
||||
probabilities = jax.nn.softmax(sorted_logits)
|
||||
# Calculate cumulative_probabilities as the prefix-sum array of
|
||||
# probabilities
|
||||
cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
|
||||
# We want to remove tokens with cumulative probability higher
|
||||
# than top_p
|
||||
sorted_indices_to_remove = cumulative_probabilities > cls.top_p
|
||||
# Don't ever remove the token with the highest logit, even if
|
||||
# the probability is higher than top_p
|
||||
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
||||
# Unsort and remove
|
||||
_, indices_to_remove = jax.lax.sort_key_val(
|
||||
jnp.argsort(-scores),
|
||||
sorted_indices_to_remove,
|
||||
)
|
||||
return jnp.where(indices_to_remove, -jnp.inf, scores)
|
||||
|
||||
@classmethod
|
||||
def value_is_valid(cls) -> bool:
|
||||
return cls.top_p < 1.0
|
||||
@@ -242,16 +213,6 @@ class TopK(Warper):
|
||||
)
|
||||
return np.where(indices_to_remove, -np.inf, scores)
|
||||
|
||||
@classmethod
|
||||
def jax_static(cls, scores: jnp.array) -> jnp.array:
|
||||
sorted_indices_to_remove = jnp.arange(len(scores)) >= cls.top_k
|
||||
|
||||
_, indices_to_remove = jax.lax.sort_key_val(
|
||||
jnp.argsort(-scores),
|
||||
sorted_indices_to_remove,
|
||||
)
|
||||
return jnp.where(indices_to_remove, -jnp.inf, scores)
|
||||
|
||||
@classmethod
|
||||
def value_is_valid(cls) -> bool:
|
||||
return cls.top_k > 0
|
||||
@@ -339,30 +300,6 @@ class TailFree(Warper):
|
||||
)
|
||||
return np.where(indices_to_remove, -np.inf, scores)
|
||||
|
||||
@classmethod
|
||||
def jax_static(cls, scores: jnp.array) -> jnp.array:
|
||||
sorted_logits = -jnp.sort(-scores)
|
||||
probabilities = jax.nn.softmax(sorted_logits)
|
||||
|
||||
d2 = jnp.diff(jnp.diff(probabilities))
|
||||
d2 = jnp.abs(d2)
|
||||
d2 = d2 / d2.sum(axis=-1, keepdims=True)
|
||||
|
||||
cumulative_d2 = jnp.cumsum(d2, axis=-1)
|
||||
sorted_indices_to_remove = cumulative_d2 > cls.tfs
|
||||
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
||||
sorted_indices_to_remove = jnp.pad(
|
||||
sorted_indices_to_remove,
|
||||
(0, 2),
|
||||
constant_values=True,
|
||||
)
|
||||
|
||||
_, indices_to_remove = jax.lax.sort_key_val(
|
||||
jnp.argsort(-scores),
|
||||
sorted_indices_to_remove,
|
||||
)
|
||||
return jnp.where(indices_to_remove, -jnp.inf, scores)
|
||||
|
||||
@classmethod
|
||||
def value_is_valid(cls) -> bool:
|
||||
return cls.tfs < 1.0
|
||||
@@ -443,25 +380,6 @@ class Typical(Warper):
|
||||
)
|
||||
return np.where(indices_to_remove, -jnp.inf, scores)
|
||||
|
||||
@classmethod
|
||||
def jax_static(cls, scores: jnp.array) -> jnp.array:
|
||||
probs = jax.nn.softmax(scores)
|
||||
log_probs = jnp.log(probs)
|
||||
|
||||
neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
|
||||
entropy_deviation = jnp.abs(neg_entropy - log_probs)
|
||||
|
||||
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
|
||||
sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= cls.typical
|
||||
sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
|
||||
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
||||
|
||||
_, indices_to_remove = jax.lax.sort_key_val(
|
||||
jnp.argsort(entropy_deviation),
|
||||
sorted_indices_to_remove,
|
||||
)
|
||||
return jnp.where(indices_to_remove, -jnp.inf, scores)
|
||||
|
||||
@classmethod
|
||||
def value_is_valid(cls) -> bool:
|
||||
return cls.typical < 1.0
|
||||
@@ -504,14 +422,6 @@ class TopA(Warper):
|
||||
probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def jax_static(cls, scores: jnp.array) -> jnp.array:
|
||||
probabilities = jax.nn.softmax(scores)
|
||||
probs_max = probabilities.max()
|
||||
return jnp.where(
|
||||
probabilities < probs_max * probs_max * cls.top_a, -jnp.inf, scores
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def value_is_valid(cls) -> bool:
|
||||
return cls.top_a > 0.0
|
||||
@@ -557,75 +467,6 @@ class RepetitionPenalty(Warper):
|
||||
|
||||
return scores
|
||||
|
||||
@classmethod
|
||||
def jax_static(
|
||||
cls,
|
||||
logits: jnp.array,
|
||||
tokens: jnp.array,
|
||||
generated_index,
|
||||
) -> jnp.array:
|
||||
"""
|
||||
This gets called to apply repetition penalty to the 1D array logits
|
||||
using the provided 1D array of tokens to penalize
|
||||
"""
|
||||
rpslope = jnp.int32(cls.rep_pen_slope)
|
||||
rprange = jnp.int32(cls.rep_pen_range)
|
||||
repetition_penalty = cls.rep_pen
|
||||
|
||||
clipped_rprange = jax.lax.cond(
|
||||
rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
|
||||
)
|
||||
|
||||
penalty_arange = jnp.roll(
|
||||
jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
|
||||
generated_index,
|
||||
axis=-1,
|
||||
)
|
||||
# Make a new array with the same length as the tokens array but with
|
||||
# each element replaced by the value at the corresponding index in the
|
||||
# logits array; e.g.
|
||||
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
||||
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
||||
penalty_logits = jnp.take(logits, tokens)
|
||||
# Repetition penalty slope
|
||||
def apply_slope(carry):
|
||||
repetition_penalty, rprange = carry
|
||||
_penalty = (penalty_arange / (rprange - 1)) * 2 - 1
|
||||
_penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
|
||||
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
|
||||
return _penalty
|
||||
|
||||
repetition_penalty = jax.lax.cond(
|
||||
(rpslope != 0.0)
|
||||
& (rprange > 0), # Not a typo; do not use `and` here, it makes JAX crash
|
||||
apply_slope,
|
||||
lambda carry: jnp.full(tokens.shape, carry[0]),
|
||||
(repetition_penalty, rprange),
|
||||
)
|
||||
# Divide positive values by repetition_penalty and multiply negative
|
||||
# values by repetition_penalty (the academic publication that described
|
||||
# this technique actually just only divided, but that would cause tokens
|
||||
# with negative logits to become more likely, which is obviously wrong)
|
||||
if cls.use_alt_rep_pen:
|
||||
penalty_logits = jnp.where(
|
||||
penalty_arange >= 0,
|
||||
penalty_logits - jnp.log(repetition_penalty),
|
||||
penalty_logits,
|
||||
)
|
||||
else:
|
||||
penalty_logits = jnp.where(
|
||||
penalty_arange >= 0,
|
||||
jnp.where(
|
||||
penalty_logits > 0,
|
||||
penalty_logits / repetition_penalty,
|
||||
penalty_logits * repetition_penalty,
|
||||
),
|
||||
penalty_logits,
|
||||
)
|
||||
# Finally, put those penalized logit values back into their original
|
||||
# positions in the logits array
|
||||
return logits.at[tokens].set(penalty_logits)
|
||||
|
||||
@classmethod
|
||||
def jax_dynamic(
|
||||
cls,
|
||||
|
@@ -226,21 +226,185 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
|
||||
# probability distribution)
|
||||
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
||||
|
||||
|
||||
def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
def kobold_sample_static(
|
||||
key,
|
||||
logits,
|
||||
rpargs,
|
||||
sampler_order: Optional[np.ndarray] = None,
|
||||
top_p=0.9,
|
||||
temp=0.5,
|
||||
top_k=0,
|
||||
tfs=1.0,
|
||||
typical=1.0,
|
||||
top_a=0.0,
|
||||
):
|
||||
'''
|
||||
This gets called by generate_loop_fn to apply a series of 6 filters
|
||||
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
||||
before picking one token using the modified logits
|
||||
'''
|
||||
|
||||
# Lame to have these here instead of modeling/warpers.py but JAX JIT stuff >:(
|
||||
# For documentation see modeling/warpers.py
|
||||
def sample_top_k(scores: jnp.array) -> jnp.array:
|
||||
sorted_indices_to_remove = jnp.arange(len(scores)) >= top_k
|
||||
|
||||
_, indices_to_remove = jax.lax.sort_key_val(
|
||||
jnp.argsort(-scores),
|
||||
sorted_indices_to_remove,
|
||||
)
|
||||
return jnp.where(indices_to_remove, -jnp.inf, scores)
|
||||
|
||||
|
||||
def sample_top_a(scores: jnp.array) -> jnp.array:
|
||||
probabilities = jax.nn.softmax(scores)
|
||||
probs_max = probabilities.max()
|
||||
return jnp.where(
|
||||
probabilities < probs_max * probs_max * top_a, -jnp.inf, scores
|
||||
)
|
||||
|
||||
|
||||
def sample_top_p(scores: jnp.array) -> jnp.array:
|
||||
sorted_logits = -jnp.sort(-scores)
|
||||
probabilities = jax.nn.softmax(sorted_logits)
|
||||
|
||||
cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
|
||||
sorted_indices_to_remove = cumulative_probabilities > top_p
|
||||
|
||||
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
||||
_, indices_to_remove = jax.lax.sort_key_val(
|
||||
jnp.argsort(-scores),
|
||||
sorted_indices_to_remove,
|
||||
)
|
||||
return jnp.where(indices_to_remove, -jnp.inf, scores)
|
||||
|
||||
|
||||
def sample_tail_free(scores: jnp.array) -> jnp.array:
|
||||
sorted_logits = -jnp.sort(-scores)
|
||||
probabilities = jax.nn.softmax(sorted_logits)
|
||||
|
||||
d2 = jnp.diff(jnp.diff(probabilities))
|
||||
d2 = jnp.abs(d2)
|
||||
d2 = d2 / d2.sum(axis=-1, keepdims=True)
|
||||
|
||||
cumulative_d2 = jnp.cumsum(d2, axis=-1)
|
||||
sorted_indices_to_remove = cumulative_d2 > tfs
|
||||
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
||||
sorted_indices_to_remove = jnp.pad(
|
||||
sorted_indices_to_remove,
|
||||
(0, 2),
|
||||
constant_values=True,
|
||||
)
|
||||
|
||||
_, indices_to_remove = jax.lax.sort_key_val(
|
||||
jnp.argsort(-scores),
|
||||
sorted_indices_to_remove,
|
||||
)
|
||||
return jnp.where(indices_to_remove, -jnp.inf, scores)
|
||||
|
||||
|
||||
def sample_typical(scores: jnp.array) -> jnp.array:
|
||||
probs = jax.nn.softmax(scores)
|
||||
log_probs = jnp.log(probs)
|
||||
|
||||
neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
|
||||
entropy_deviation = jnp.abs(neg_entropy - log_probs)
|
||||
|
||||
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
|
||||
sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
|
||||
sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
|
||||
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
||||
|
||||
_, indices_to_remove = jax.lax.sort_key_val(
|
||||
jnp.argsort(entropy_deviation),
|
||||
sorted_indices_to_remove,
|
||||
)
|
||||
return jnp.where(indices_to_remove, -jnp.inf, scores)
|
||||
|
||||
|
||||
def sample_temperature(scores: jnp.array) -> jnp.array:
|
||||
return scores / temp
|
||||
|
||||
|
||||
def sample_repetition_penalty(
|
||||
logits: jnp.array,
|
||||
tokens: jnp.array,
|
||||
repetition_penalty,
|
||||
generated_index,
|
||||
rpslope,
|
||||
rprange
|
||||
) -> jnp.array:
|
||||
"""
|
||||
This gets called to apply repetition penalty to the 1D array logits
|
||||
using the provided 1D array of tokens to penalize
|
||||
"""
|
||||
rpslope = jnp.int32(rpslope)
|
||||
rprange = jnp.int32(rprange)
|
||||
|
||||
clipped_rprange = jax.lax.cond(
|
||||
rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
|
||||
)
|
||||
|
||||
penalty_arange = jnp.roll(
|
||||
jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
|
||||
generated_index,
|
||||
axis=-1,
|
||||
)
|
||||
# Make a new array with the same length as the tokens array but with
|
||||
# each element replaced by the value at the corresponding index in the
|
||||
# logits array; e.g.
|
||||
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
||||
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
||||
penalty_logits = jnp.take(logits, tokens)
|
||||
|
||||
# Repetition penalty slope
|
||||
def apply_slope(carry):
|
||||
repetition_penalty, rprange = carry
|
||||
_penalty = (penalty_arange / (rprange - 1)) * 2 - 1
|
||||
_penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
|
||||
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
|
||||
return _penalty
|
||||
|
||||
repetition_penalty = jax.lax.cond(
|
||||
(rpslope != 0.0)
|
||||
& (rprange > 0), # Not a typo; do not use `and` here, it makes JAX crash
|
||||
apply_slope,
|
||||
lambda carry: jnp.full(tokens.shape, carry[0]),
|
||||
(repetition_penalty, rprange),
|
||||
)
|
||||
# Divide positive values by repetition_penalty and multiply negative
|
||||
# values by repetition_penalty (the academic publication that described
|
||||
# this technique actually just only divided, but that would cause tokens
|
||||
# with negative logits to become more likely, which is obviously wrong)
|
||||
if koboldai_vars.use_alt_rep_pen:
|
||||
penalty_logits = jnp.where(
|
||||
penalty_arange >= 0,
|
||||
penalty_logits - jnp.log(repetition_penalty),
|
||||
penalty_logits,
|
||||
)
|
||||
else:
|
||||
penalty_logits = jnp.where(
|
||||
penalty_arange >= 0,
|
||||
jnp.where(
|
||||
penalty_logits > 0,
|
||||
penalty_logits / repetition_penalty,
|
||||
penalty_logits * repetition_penalty,
|
||||
),
|
||||
penalty_logits,
|
||||
)
|
||||
# Finally, put those penalized logit values back into their original
|
||||
# positions in the logits array
|
||||
return logits.at[tokens].set(penalty_logits)
|
||||
|
||||
|
||||
for k in sampler_order:
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), warpers.TopK.jax_static, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), warpers.TopA.jax_static, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), warpers.TopP.jax_static, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), warpers.TailFree.jax_static, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), warpers.Typical.jax_static, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), warpers.Temperature.jax_static, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: warpers.RepetitionPenalty.jax_static(*x), lambda x: x[0], (logits, *rpargs))
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), sample_top_k, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), sample_top_a, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), sample_top_p, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), sample_tail_free, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), sample_typical, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), sample_temperature, lambda x: x, logits)
|
||||
logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: sample_repetition_penalty(*x), lambda x: x[0], (logits, *rpargs))
|
||||
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
|
||||
|
||||
pad_token_id = 50256
|
||||
@@ -357,11 +521,10 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||
logits,
|
||||
(
|
||||
generated,
|
||||
# repetition_penalty,
|
||||
repetition_penalty,
|
||||
generated_index,
|
||||
# gen_length,
|
||||
# rpslope,
|
||||
# rprange,
|
||||
rpslope,
|
||||
rprange,
|
||||
),
|
||||
**sampler_options,
|
||||
)
|
||||
|
Reference in New Issue
Block a user