Undo pretty code because I haven't cracked the jax enigma yet
--- a/modeling/warpers.py
+++ b/modeling/warpers.py
@@ -68,13 +68,16 @@ class Warper:
     All Warpers should be singletons defined in the warpers.py file.
 
     To make a new warper/sampler:
-    - Create your class, implementing `torch()`, `jax_dynamic`, `jax_static`,
-      and `value_is_valid()`. Dynamic and static methods are separated for Jax
+    - Create your class, implementing `torch()`, `jax_dynamic`, and
+      `value_is_valid()`. Dynamic and static methods are separated for Jax
       due to how it does JIT compilation of functions (from what I gather).
-      These `static` methods are very picky about what you can and can't do
+      These static methods are very picky about what you can and can't do
       with data at runtime and thus sometimes need to be implemented
       differently than the `dynamic` methods, which are more like the Torch
       methods.
+    - Implement your TPU static function (if applicable) in
+      tpu_mtj_backend.py->kobold_sample_static. If you do not do this, your
+      sampler will only work on the much slower dynamic generation mode on TPU.
     - Add it to Warper.from_id and tpu_mtj_backend.kobold_sample_static.
     - Add it to the UI/sampler_order.
 
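As an illustration of the recipe above, here is a minimal sketch of a new warper; the `LogitBias` class, its `bias` field, and its values are hypothetical stand-ins, not part of this commit or the codebase:

```python
import numpy as np
import torch

class LogitBias:  # in modeling/warpers.py this would subclass Warper
    bias: float = 0.0  # hypothetical singleton state

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        return scores + cls.bias

    @classmethod
    def jax_dynamic(cls, scores: np.array) -> np.array:
        # Dynamic TPU path: ordinary eager code, no JIT restrictions
        return scores + cls.bias

    @classmethod
    def value_is_valid(cls) -> bool:
        # Skip this warper whenever it cannot change the scores
        return cls.bias != 0.0
```

Per the updated recipe, the static counterpart would now be written inline in tpu_mtj_backend.py->kobold_sample_static rather than as a `jax_static` method here.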
@@ -103,10 +106,6 @@ class Warper:
     def jax_dynamic(cls, scores: np.array) -> np.array:
         raise NotImplementedError("Please override `jax_dynamic()`.")
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        raise NotImplementedError("Please override `jax_static()`.")
-
     @classmethod
     def value_is_valid(cls) -> bool:
         raise NotImplementedError("Please override `value_is_valid()`.")
@@ -125,10 +124,6 @@ class Temperature(Warper):
     def jax_dynamic(cls, scores: np.array) -> np.array:
         return scores / cls.temperature
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        return scores / cls.temperature
-
    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.temperature != 1.0
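The "picky" static methods exist because of JIT tracing. A standalone sketch (standard JAX behavior, not repo code) of an operation that survives `jax.jit` next to one that does not:

```python
import jax
import jax.numpy as jnp

@jax.jit
def static_friendly(scores):
    # Fine under jit: the mask changes values, never the array's shape
    return jnp.where(scores <= 0, -jnp.inf, scores)

@jax.jit
def static_hostile(scores):
    # Fails under jit: the result's shape depends on runtime values
    return scores[scores > 0]

scores = jnp.array([1.0, -2.0, 3.0])
print(static_friendly(scores))  # [  1. -inf   3.]
# static_hostile(scores)  # raises NonConcreteBooleanIndexError
```

This is why every static sampler below masks tokens to `-inf` in place rather than dropping them from the array.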
@@ -181,30 +176,6 @@ class TopP(Warper):
         )
         return np.where(indices_to_remove, -np.inf, scores)
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        # Sort the logits array in descending order, replace every element
-        # with e (Euler's number) to the power of that element, and divide
-        # each element of the new array by the sum of the elements in the
-        # new array
-        sorted_logits = -jnp.sort(-scores)
-        probabilities = jax.nn.softmax(sorted_logits)
-        # Calculate cumulative_probabilities as the prefix-sum array of
-        # probabilities
-        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
-        # We want to remove tokens with cumulative probability higher
-        # than top_p
-        sorted_indices_to_remove = cumulative_probabilities > cls.top_p
-        # Don't ever remove the token with the highest logit, even if
-        # the probability is higher than top_p
-        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
-        # Unsort and remove
-        _, indices_to_remove = jax.lax.sort_key_val(
-            jnp.argsort(-scores),
-            sorted_indices_to_remove,
-        )
-        return jnp.where(indices_to_remove, -jnp.inf, scores)
-
     @classmethod
     def value_is_valid(cls) -> bool:
         return cls.top_p < 1.0
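To see what the deleted `TopP.jax_static` computes, the same algorithm can be run standalone on toy logits (illustrative numbers only):

```python
import jax
import jax.numpy as jnp

top_p = 0.9
scores = jnp.array([2.0, 0.0, 1.0, -1.0])

sorted_logits = -jnp.sort(-scores)              # [ 2.  1.  0. -1.]
probabilities = jax.nn.softmax(sorted_logits)   # ~[0.64 0.24 0.09 0.03]
cumulative = jnp.cumsum(probabilities, axis=-1)
sorted_remove = cumulative > top_p              # [F F T T]
sorted_remove = sorted_remove.at[0].set(False)  # always keep the best token
_, remove = jax.lax.sort_key_val(jnp.argsort(-scores), sorted_remove)
print(jnp.where(remove, -jnp.inf, scores))      # [ 2. -inf  1. -inf]
```

Only the tokens covering roughly the top 90% of probability mass survive; everything else is masked to `-inf`.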
@@ -242,16 +213,6 @@ class TopK(Warper):
         )
         return np.where(indices_to_remove, -np.inf, scores)
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        sorted_indices_to_remove = jnp.arange(len(scores)) >= cls.top_k
-
-        _, indices_to_remove = jax.lax.sort_key_val(
-            jnp.argsort(-scores),
-            sorted_indices_to_remove,
-        )
-        return jnp.where(indices_to_remove, -jnp.inf, scores)
-
     @classmethod
     def value_is_valid(cls) -> bool:
         return cls.top_k > 0
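Each of these static samplers builds its removal mask in sorted order and then maps it back with the same `jax.lax.sort_key_val` unsort trick; a standalone sketch with toy numbers:

```python
import jax
import jax.numpy as jnp

scores = jnp.array([0.5, 2.0, -1.0, 1.0])
order = jnp.argsort(-scores)                # descending order: [1 3 0 2]
mask_sorted = jnp.arange(len(scores)) >= 2  # keep top_k=2, in sorted space
# Sorting `order` back to [0 1 2 3] drags the mask values along with it,
# mapping each sorted-space decision back to its original token position
_, mask = jax.lax.sort_key_val(order, mask_sorted)
print(jnp.where(mask, -jnp.inf, scores))    # [-inf  2. -inf  1.]
```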
@@ -339,30 +300,6 @@ class TailFree(Warper):
         )
         return np.where(indices_to_remove, -np.inf, scores)
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        sorted_logits = -jnp.sort(-scores)
-        probabilities = jax.nn.softmax(sorted_logits)
-
-        d2 = jnp.diff(jnp.diff(probabilities))
-        d2 = jnp.abs(d2)
-        d2 = d2 / d2.sum(axis=-1, keepdims=True)
-
-        cumulative_d2 = jnp.cumsum(d2, axis=-1)
-        sorted_indices_to_remove = cumulative_d2 > cls.tfs
-        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
-        sorted_indices_to_remove = jnp.pad(
-            sorted_indices_to_remove,
-            (0, 2),
-            constant_values=True,
-        )
-
-        _, indices_to_remove = jax.lax.sort_key_val(
-            jnp.argsort(-scores),
-            sorted_indices_to_remove,
-        )
-        return jnp.where(indices_to_remove, -jnp.inf, scores)
-
     @classmethod
     def value_is_valid(cls) -> bool:
         return cls.tfs < 1.0
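The tail-free body above keys off the curvature (second difference) of the sorted probabilities; a standalone sketch with toy numbers, including why the mask is padded back out by two entries:

```python
import jax
import jax.numpy as jnp

tfs = 0.95  # illustrative value
scores = jnp.array([3.0, 2.0, 1.0, 0.0, -1.0])

probs = jax.nn.softmax(-jnp.sort(-scores))
# Curvature of the sorted distribution, normalized into weights
d2 = jnp.abs(jnp.diff(jnp.diff(probs)))
d2 = d2 / d2.sum(axis=-1, keepdims=True)
remove = jnp.cumsum(d2, axis=-1) > tfs
remove = remove.at[0].set(False)
# Two diffs shortened the mask by two entries; pad it back to full
# length, unconditionally dropping the two flattest tail tokens
remove = jnp.pad(remove, (0, 2), constant_values=True)
print(remove)  # boolean mask aligned with the sorted order
```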
@@ -443,25 +380,6 @@ class Typical(Warper):
         )
         return np.where(indices_to_remove, -jnp.inf, scores)
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        probs = jax.nn.softmax(scores)
-        log_probs = jnp.log(probs)
-
-        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
-        entropy_deviation = jnp.abs(neg_entropy - log_probs)
-
-        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
-        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= cls.typical
-        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
-        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
-
-        _, indices_to_remove = jax.lax.sort_key_val(
-            jnp.argsort(entropy_deviation),
-            sorted_indices_to_remove,
-        )
-        return jnp.where(indices_to_remove, -jnp.inf, scores)
-
     @classmethod
     def value_is_valid(cls) -> bool:
         return cls.typical < 1.0
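Typical sampling ranks tokens by how far their information content deviates from the distribution's entropy; a toy-number sketch of the deviation computation at the heart of the deleted method:

```python
import jax
import jax.numpy as jnp

scores = jnp.array([2.0, 1.0, 0.0, -1.0])

probs = jax.nn.softmax(scores)
log_probs = jnp.log(probs)
# Negative entropy of the whole distribution (a scalar, kept broadcastable)
neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
# Distance of each token's information content from the expected value;
# the most "typical" tokens (smallest deviation) are kept first
entropy_deviation = jnp.abs(neg_entropy - log_probs)
print(entropy_deviation)
```

The rest of the method sorts by that deviation, keeps tokens until the cumulative probability reaches `typical`, and unsorts the mask as in the other samplers.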
@@ -504,14 +422,6 @@ class TopA(Warper):
             probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
         )
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        probabilities = jax.nn.softmax(scores)
-        probs_max = probabilities.max()
-        return jnp.where(
-            probabilities < probs_max * probs_max * cls.top_a, -jnp.inf, scores
-        )
-
     @classmethod
     def value_is_valid(cls) -> bool:
         return cls.top_a > 0.0
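The top-a cutoff is quadratic in the best token's probability, so the filter tightens sharply as the model grows confident; with illustrative numbers:

```python
# Threshold is probs_max**2 * top_a
top_a = 0.8
print(0.9 * 0.9 * top_a)  # confident model: drop tokens under p = 0.648
print(0.2 * 0.2 * top_a)  # uncertain model: drop tokens under p = 0.032
```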
@@ -557,75 +467,6 @@ class RepetitionPenalty(Warper):
 
         return scores
 
-    @classmethod
-    def jax_static(
-        cls,
-        logits: jnp.array,
-        tokens: jnp.array,
-        generated_index,
-    ) -> jnp.array:
-        """
-        This gets called to apply repetition penalty to the 1D array logits
-        using the provided 1D array of tokens to penalize
-        """
-        rpslope = jnp.int32(cls.rep_pen_slope)
-        rprange = jnp.int32(cls.rep_pen_range)
-        repetition_penalty = cls.rep_pen
-
-        clipped_rprange = jax.lax.cond(
-            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
-        )
-
-        penalty_arange = jnp.roll(
-            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
-            generated_index,
-            axis=-1,
-        )
-        # Make a new array with the same length as the tokens array but with
-        # each element replaced by the value at the corresponding index in the
-        # logits array; e.g.
-        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
-        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
-        penalty_logits = jnp.take(logits, tokens)
-        # Repetition penalty slope
-        def apply_slope(carry):
-            repetition_penalty, rprange = carry
-            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
-            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
-            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
-            return _penalty
-
-        repetition_penalty = jax.lax.cond(
-            (rpslope != 0.0)
-            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
-            apply_slope,
-            lambda carry: jnp.full(tokens.shape, carry[0]),
-            (repetition_penalty, rprange),
-        )
-        # Divide positive values by repetition_penalty and multiply negative
-        # values by repetition_penalty (the academic publication that described
-        # this technique actually just only divided, but that would cause tokens
-        # with negative logits to become more likely, which is obviously wrong)
-        if cls.use_alt_rep_pen:
-            penalty_logits = jnp.where(
-                penalty_arange >= 0,
-                penalty_logits - jnp.log(repetition_penalty),
-                penalty_logits,
-            )
-        else:
-            penalty_logits = jnp.where(
-                penalty_arange >= 0,
-                jnp.where(
-                    penalty_logits > 0,
-                    penalty_logits / repetition_penalty,
-                    penalty_logits * repetition_penalty,
-                ),
-                penalty_logits,
-            )
-        # Finally, put those penalized logit values back into their original
-        # positions in the logits array
-        return logits.at[tokens].set(penalty_logits)
-
     @classmethod
     def jax_dynamic(
         cls,
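The core of the deleted method is a gather/scatter pair; here it is run standalone on the toy values from the code comments (the halving is a stand-in for the real penalty math):

```python
import jax.numpy as jnp

logits = jnp.array([77.0, 5.0, 3.0, 98.0])
tokens = jnp.array([0, 1, 2, 3, 2, 3, 1])

penalty_logits = jnp.take(logits, tokens)  # [77.  5.  3. 98.  3. 98.  5.]
penalty_logits = penalty_logits / 2.0      # stand-in for the real penalty
# Scatter the penalized values back to their original positions; writes
# for repeated token ids land on the same slot
print(logits.at[tokens].set(penalty_logits))  # [38.5  2.5  1.5 49. ]
```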
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -226,21 +226,185 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
     # probability distribution)
     return jax.random.categorical(key, logits, -1).astype(np.uint32)
 
-def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
+def kobold_sample_static(
+    key,
+    logits,
+    rpargs,
+    sampler_order: Optional[np.ndarray] = None,
+    top_p=0.9,
+    temp=0.5,
+    top_k=0,
+    tfs=1.0,
+    typical=1.0,
+    top_a=0.0,
+):
     '''
     This gets called by generate_loop_fn to apply a series of 6 filters
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
     before picking one token using the modified logits
     '''
 
+    # Lame to have these here instead of modeling/warpers.py but JAX JIT stuff >:(
+    # For documentation see modeling/warpers.py
+    def sample_top_k(scores: jnp.array) -> jnp.array:
+        sorted_indices_to_remove = jnp.arange(len(scores)) >= top_k
+
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+
+
+    def sample_top_a(scores: jnp.array) -> jnp.array:
+        probabilities = jax.nn.softmax(scores)
+        probs_max = probabilities.max()
+        return jnp.where(
+            probabilities < probs_max * probs_max * top_a, -jnp.inf, scores
+        )
+
+
+    def sample_top_p(scores: jnp.array) -> jnp.array:
+        sorted_logits = -jnp.sort(-scores)
+        probabilities = jax.nn.softmax(sorted_logits)
+
+        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
+        sorted_indices_to_remove = cumulative_probabilities > top_p
+
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+
+
+    def sample_tail_free(scores: jnp.array) -> jnp.array:
+        sorted_logits = -jnp.sort(-scores)
+        probabilities = jax.nn.softmax(sorted_logits)
+
+        d2 = jnp.diff(jnp.diff(probabilities))
+        d2 = jnp.abs(d2)
+        d2 = d2 / d2.sum(axis=-1, keepdims=True)
+
+        cumulative_d2 = jnp.cumsum(d2, axis=-1)
+        sorted_indices_to_remove = cumulative_d2 > tfs
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        sorted_indices_to_remove = jnp.pad(
+            sorted_indices_to_remove,
+            (0, 2),
+            constant_values=True,
+        )
+
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+
+
+    def sample_typical(scores: jnp.array) -> jnp.array:
+        probs = jax.nn.softmax(scores)
+        log_probs = jnp.log(probs)
+
+        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
+        entropy_deviation = jnp.abs(neg_entropy - log_probs)
+
+        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
+        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
+        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(entropy_deviation),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+
+
+    def sample_temperature(scores: jnp.array) -> jnp.array:
+        return scores / temp
+
+
+    def sample_repetition_penalty(
+        logits: jnp.array,
+        tokens: jnp.array,
+        repetition_penalty,
+        generated_index,
+        rpslope,
+        rprange
+    ) -> jnp.array:
+        """
+        This gets called to apply repetition penalty to the 1D array logits
+        using the provided 1D array of tokens to penalize
+        """
+        rpslope = jnp.int32(rpslope)
+        rprange = jnp.int32(rprange)
+
+        clipped_rprange = jax.lax.cond(
+            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
+        )
+
+        penalty_arange = jnp.roll(
+            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
+            generated_index,
+            axis=-1,
+        )
+        # Make a new array with the same length as the tokens array but with
+        # each element replaced by the value at the corresponding index in the
+        # logits array; e.g.
+        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
+        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
+        penalty_logits = jnp.take(logits, tokens)
+
+        # Repetition penalty slope
+        def apply_slope(carry):
+            repetition_penalty, rprange = carry
+            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
+            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
+            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
+            return _penalty
+
+        repetition_penalty = jax.lax.cond(
+            (rpslope != 0.0)
+            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
+            apply_slope,
+            lambda carry: jnp.full(tokens.shape, carry[0]),
+            (repetition_penalty, rprange),
+        )
+        # Divide positive values by repetition_penalty and multiply negative
+        # values by repetition_penalty (the academic publication that described
+        # this technique actually just only divided, but that would cause tokens
+        # with negative logits to become more likely, which is obviously wrong)
+        if koboldai_vars.use_alt_rep_pen:
+            penalty_logits = jnp.where(
+                penalty_arange >= 0,
+                penalty_logits - jnp.log(repetition_penalty),
+                penalty_logits,
+            )
+        else:
+            penalty_logits = jnp.where(
+                penalty_arange >= 0,
+                jnp.where(
+                    penalty_logits > 0,
+                    penalty_logits / repetition_penalty,
+                    penalty_logits * repetition_penalty,
+                ),
+                penalty_logits,
+            )
+        # Finally, put those penalized logit values back into their original
+        # positions in the logits array
+        return logits.at[tokens].set(penalty_logits)
+
+
     for k in sampler_order:
-        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), warpers.TopK.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), warpers.TopA.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), warpers.TopP.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), warpers.TailFree.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), warpers.Typical.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), warpers.Temperature.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: warpers.RepetitionPenalty.jax_static(*x), lambda x: x[0], (logits, *rpargs))
+        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), sample_top_k, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), sample_top_a, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), sample_top_p, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), sample_tail_free, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), sample_typical, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), sample_temperature, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: sample_repetition_penalty(*x), lambda x: x[0], (logits, *rpargs))
     return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
 
 pad_token_id = 50256
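The `for k in sampler_order` loop above unrolls at trace time into a fixed chain of `jax.lax.cond` calls, one guard per sampler, which is how a runtime-ordered pipeline stays JIT-compatible; a minimal sketch of the pattern with toy transforms (not the real samplers):

```python
import jax
import jax.numpy as jnp

def apply_in_order(order, x):
    # The Python loop unrolls during tracing; each `k` is a traced scalar
    # that is compared inside lax.cond at runtime
    for k in order:
        x = jax.lax.cond(k == 0, lambda v: v * 2.0, lambda v: v, x)
        x = jax.lax.cond(k == 1, lambda v: v + 1.0, lambda v: v, x)
    return x

# order [1, 0]: first +1 fires, then *2 fires: (3 + 1) * 2 = 8
print(jax.jit(apply_in_order)(jnp.array([1, 0]), jnp.array([3.0])))  # [8.]
```

Because the chain is fixed, changing the values in `order` at runtime changes which branch fires on each pass without triggering recompilation.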
@@ -357,11 +521,10 @@ class PenalizingCausalTransformer(CausalTransformer):
                 logits,
                 (
                     generated,
-                    # repetition_penalty,
+                    repetition_penalty,
                     generated_index,
-                    # gen_length,
-                    # rpslope,
-                    # rprange,
+                    rpslope,
+                    rprange,
                 ),
                 **sampler_options,
             )