mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Undo pretty code because I haven't cracked the jax enigma yet
@@ -68,13 +68,16 @@ class Warper:
All Warpers should be singletons defined in the warpers.py file.

To make a new warper/sampler:
- Create your class, implementing `torch()`, `jax_dynamic`, `jax_static`,
  and `value_is_valid()`. Dynamic and static methods are separated for Jax
- Create your class, implementing `torch()`, `jax_dynamic`, and
  `value_is_valid()`. Dynamic and static methods are separated for Jax
  due to how it does JIT compilation of functions (from what I gather).
  These `static` methods are very picky about what you can and can't do
  These static methods are very picky about what you can and can't do
  with data at runtime and thus sometimes need to be implemented
  differently than the `dynamic` methods, which are more like the Torch
  methods.
- Implement your TPU static function (if applicable) in
  tpu_mtj_backend.py->kobold_sample_static. If you do not do this, your
  sampler will only work in the much slower dynamic generation mode on TPU.
- Add it to Warper.from_id and tpu_mtj_backend.kobold_sample_static.
- Add it to the UI/sampler_order.
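A minimal sketch of a warper following the updated instructions above; the `LogitBias` class and its `bias` attribute are hypothetical, chosen purely for illustration, not part of the codebase:

    import numpy as np
    import jax.numpy as jnp

    class LogitBias(Warper):
        """Hypothetical example warper: shifts every logit by a constant."""

        bias: float = 0.0

        @classmethod
        def torch(cls, scores):
            return scores + cls.bias

        @classmethod
        def jax_dynamic(cls, scores: np.array) -> np.array:
            return scores + cls.bias

        @classmethod
        def value_is_valid(cls) -> bool:
            # A warper is skipped when its value would be a no-op
            return cls.bias != 0.0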
@@ -103,10 +106,6 @@ class Warper:
    def jax_dynamic(cls, scores: np.array) -> np.array:
        raise NotImplementedError("Please override `jax_dynamic()`.")

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        raise NotImplementedError("Please override `jax_static()`.")

    @classmethod
    def value_is_valid(cls) -> bool:
        raise NotImplementedError("Please override `value_is_valid()`.")
@@ -125,10 +124,6 @@ class Temperature(Warper):
    def jax_dynamic(cls, scores: np.array) -> np.array:
        return scores / cls.temperature

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        return scores / cls.temperature

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.temperature != 1.0
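As a quick illustration of what this division does (toy values, not from the codebase), temperatures below 1 sharpen the softmax distribution and temperatures above 1 flatten it:

    import jax
    import jax.numpy as jnp

    logits = jnp.array([2.0, 1.0, 0.5])
    for temperature in (0.5, 1.0, 2.0):
        # Lower temperature exaggerates the gaps between logits; higher shrinks them
        print(temperature, jax.nn.softmax(logits / temperature))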
@@ -181,30 +176,6 @@ class TopP(Warper):
        )
        return np.where(indices_to_remove, -np.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        # Sort the logits array in descending order, replace every element
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        sorted_logits = -jnp.sort(-scores)
        probabilities = jax.nn.softmax(sorted_logits)
        # Calculate cumulative_probabilities as the prefix-sum array of
        # probabilities
        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
        # We want to remove tokens with cumulative probability higher
        # than top_p
        sorted_indices_to_remove = cumulative_probabilities > cls.top_p
        # Don't ever remove the token with the highest logit, even if
        # the probability is higher than top_p
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.top_p < 1.0
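The same nucleus-filtering steps can be run standalone on a toy array; the logits and the 0.8 cutoff below are arbitrary, chosen just to show the mask:

    import jax
    import jax.numpy as jnp

    scores = jnp.array([1.0, 3.0, 0.0, 2.0])
    top_p = 0.8

    sorted_logits = -jnp.sort(-scores)  # descending order
    cumulative = jnp.cumsum(jax.nn.softmax(sorted_logits), axis=-1)
    sorted_remove = cumulative > top_p
    sorted_remove = sorted_remove.at[0].set(False)  # always keep the best token
    # Unsort: co-sort the mask by the descending ranks so it lines up with scores
    _, remove = jax.lax.sort_key_val(jnp.argsort(-scores), sorted_remove)
    print(jnp.where(remove, -jnp.inf, scores))  # [-inf 3. -inf -inf]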
@@ -242,16 +213,6 @@ class TopK(Warper):
        )
        return np.where(indices_to_remove, -np.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        # Mark everything past rank top_k for removal, then unsort the mask
        sorted_indices_to_remove = jnp.arange(len(scores)) >= cls.top_k

        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.top_k > 0
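Top-k reuses the same unsort trick; since its removal rule depends only on rank, the sorted mask is just a comparison against an index range. A toy run (arbitrary values again):

    import jax
    import jax.numpy as jnp

    scores = jnp.array([0.5, 2.0, 1.0, 3.0])
    top_k = 2

    sorted_remove = jnp.arange(scores.shape[-1]) >= top_k  # drop everything past rank k
    _, remove = jax.lax.sort_key_val(jnp.argsort(-scores), sorted_remove)
    print(jnp.where(remove, -jnp.inf, scores))  # [-inf 2. -inf 3.]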
@@ -339,30 +300,6 @@ class TailFree(Warper):
        )
        return np.where(indices_to_remove, -np.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        sorted_logits = -jnp.sort(-scores)
        probabilities = jax.nn.softmax(sorted_logits)

        # Normalized magnitude of the second derivative of the sorted
        # probability curve; large values mark the tail of the distribution
        d2 = jnp.diff(jnp.diff(probabilities))
        d2 = jnp.abs(d2)
        d2 = d2 / d2.sum(axis=-1, keepdims=True)

        cumulative_d2 = jnp.cumsum(d2, axis=-1)
        sorted_indices_to_remove = cumulative_d2 > cls.tfs
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        # diff() was applied twice, so the mask is two elements short;
        # pad it back to full length, removing the two extra positions
        sorted_indices_to_remove = jnp.pad(
            sorted_indices_to_remove,
            (0, 2),
            constant_values=True,
        )

        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.tfs < 1.0
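One detail worth calling out: applying jnp.diff twice leaves the curvature array two elements shorter than scores, which is why the mask gets padded back to full length with True (remove) before the unsort. A minimal shape check with made-up probabilities:

    import jax.numpy as jnp

    probs = jnp.array([0.6, 0.2, 0.1, 0.06, 0.04])
    d2 = jnp.diff(jnp.diff(probs))  # second difference of the sorted curve
    print(probs.shape, d2.shape)    # (5,) (3,): two entries must be padded back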
@@ -443,25 +380,6 @@ class Typical(Warper):
        )
        return np.where(indices_to_remove, -np.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        probs = jax.nn.softmax(scores)
        log_probs = jnp.log(probs)

        # How far each token's surprisal is from the distribution's entropy
        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
        entropy_deviation = jnp.abs(neg_entropy - log_probs)

        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= cls.typical
        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)

        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(entropy_deviation),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.typical < 1.0
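The jnp.roll by one deserves a note: cumsum >= typical first flags the token that crosses the cutoff, and shifting the mask right by one keeps that boundary token in the pool. A toy illustration of just that shift (mask values are arbitrary):

    import jax.numpy as jnp

    mask = jnp.array([False, False, True, True])
    shifted = jnp.roll(mask, 1, axis=-1)  # [ True False False  True]
    print(shifted.at[0].set(False))       # [False False False  True]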
@@ -504,14 +422,6 @@ class TopA(Warper):
            probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
        )

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        probabilities = jax.nn.softmax(scores)
        probs_max = probabilities.max()
        return jnp.where(
            probabilities < probs_max * probs_max * cls.top_a, -jnp.inf, scores
        )

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.top_a > 0.0
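Top-A's cutoff scales with the square of the strongest probability, so the more confident the model is, the more of the tail gets dropped. A quick numeric check (arbitrary logits, top_a = 0.5):

    import jax
    import jax.numpy as jnp

    scores = jnp.array([3.0, 1.0, 0.0])
    top_a = 0.5

    probs = jax.nn.softmax(scores)        # approx. [0.84 0.11 0.04]
    threshold = probs.max() ** 2 * top_a  # approx. 0.36
    print(jnp.where(probs < threshold, -jnp.inf, scores))  # [3. -inf -inf]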
@@ -557,75 +467,6 @@ class RepetitionPenalty(Warper):

        return scores

    @classmethod
    def jax_static(
        cls,
        logits: jnp.array,
        tokens: jnp.array,
        generated_index,
    ) -> jnp.array:
        """
        This gets called to apply repetition penalty to the 1D array logits
        using the provided 1D array of tokens to penalize
        """
        rpslope = jnp.float32(cls.rep_pen_slope)  # the slope is fractional, so keep it as a float
        rprange = jnp.int32(cls.rep_pen_range)
        repetition_penalty = cls.rep_pen

        clipped_rprange = jax.lax.cond(
            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
        )

        penalty_arange = jnp.roll(
            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
            generated_index,
            axis=-1,
        )
        # Make a new array with the same length as the tokens array but with
        # each element replaced by the value at the corresponding index in the
        # logits array; e.g.
        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
        penalty_logits = jnp.take(logits, tokens)
        # Repetition penalty slope
        def apply_slope(carry):
            repetition_penalty, rprange = carry
            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
            return _penalty

        repetition_penalty = jax.lax.cond(
            (rpslope != 0.0)
            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
            apply_slope,
            lambda carry: jnp.full(tokens.shape, carry[0]),
            (repetition_penalty, rprange),
        )
        # Divide positive values by repetition_penalty and multiply negative
        # values by repetition_penalty (the academic publication that described
        # this technique actually only divided, but that would cause tokens
        # with negative logits to become more likely, which is obviously wrong)
        if cls.use_alt_rep_pen:
            penalty_logits = jnp.where(
                penalty_arange >= 0,
                penalty_logits - jnp.log(repetition_penalty),
                penalty_logits,
            )
        else:
            penalty_logits = jnp.where(
                penalty_arange >= 0,
                jnp.where(
                    penalty_logits > 0,
                    penalty_logits / repetition_penalty,
                    penalty_logits * repetition_penalty,
                ),
                penalty_logits,
            )
        # Finally, put those penalized logit values back into their original
        # positions in the logits array
        return logits.at[tokens].set(penalty_logits)

    @classmethod
    def jax_dynamic(
        cls,
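The gather/scatter pair used above (jnp.take to pull out the logits of already-seen tokens, .at[tokens].set to write the penalized values back) can be seen in isolation with the same toy numbers as the comment; rep_pen = 2.0 is arbitrary and the slope/range machinery is omitted:

    import jax.numpy as jnp

    logits = jnp.array([77.0, 5.0, 3.0, 98.0])
    tokens = jnp.array([0, 1, 2, 3, 2, 3, 1])
    rep_pen = 2.0

    penalty_logits = jnp.take(logits, tokens)  # [77. 5. 3. 98. 3. 98. 5.]
    # Non-alt branch: divide positive logits, multiply negative ones
    penalty_logits = jnp.where(
        penalty_logits > 0, penalty_logits / rep_pen, penalty_logits * rep_pen
    )
    print(logits.at[tokens].set(penalty_logits))  # [38.5 2.5 1.5 49.]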