From d496e861f4ab3f6c05e4d92bc9e29e84511cd87d Mon Sep 17 00:00:00 2001
From: onesome
Date: Tue, 25 Apr 2023 19:54:58 -0500
Subject: [PATCH] Undo pretty code because I haven't cracked the jax enigma
 yet

---
 modeling/warpers.py | 171 ++------------------------------------
 tpu_mtj_backend.py  | 189 +++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 182 insertions(+), 178 deletions(-)

diff --git a/modeling/warpers.py b/modeling/warpers.py
index 4c7dbac4..ca7e7396 100644
--- a/modeling/warpers.py
+++ b/modeling/warpers.py
@@ -68,13 +68,16 @@ class Warper:
     All Warpers should be singletons defined in the warpers.py file.
 
     To make a new warper/sampler:
-    - Create your class, implementing `torch()`, `jax_dynamic`, `jax_static`,
-      and `value_is_valid()`. Dynamic and static methods are seperated for Jax
+    - Create your class, implementing `torch()`, `jax_dynamic`, and
+      `value_is_valid()`. Dynamic and static methods are separated for Jax
       due to how it does JIT compilation of functions (from what I gather).
-      These `static` methods are very picky about what you can and can't do
+      These static methods are very picky about what you can and can't do
       with data at runtime and thus sometimes need to be implemented
       differently than the `dynamic` methods, which are more like the Torch
       methods.
+    - Implement your TPU static function (if applicable) in
+      tpu_mtj_backend.py->kobold_sample_static. If you do not do this, your
+      sampler will only work on the much slower dynamic generation mode on TPU.
     - Add it to Warper.from_id and tpu_mtj_backend.kobold_sample_static.
     - Add it to the UI/sampler_order.
 
@@ -103,10 +106,6 @@ class Warper:
     def jax_dynamic(cls, scores: np.array) -> np.array:
         raise NotImplementedError("Please override `jax_dynamic()`.")
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        raise NotImplementedError("Please override `jax_static()`.")
-
     @classmethod
     def value_is_valid(cls) -> bool:
         raise NotImplementedError("Please override `value_is_valid()`.")
@@ -125,10 +124,6 @@ class Temperature(Warper):
     def jax_dynamic(cls, scores: np.array) -> np.array:
         return scores / cls.temperature
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        return scores / cls.temperature
-
     @classmethod
     def value_is_valid(cls) -> bool:
         return cls.temperature != 1.0
@@ -181,30 +176,6 @@ class TopP(Warper):
         )
         return np.where(indices_to_remove, -np.inf, scores)
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        # Sort the logits array in descending order, replace every element
-        # with e (Euler's number) to the power of that element, and divide
-        # each element of the new array by the sum of the elements in the
-        # new array
-        sorted_logits = -jnp.sort(-scores)
-        probabilities = jax.nn.softmax(sorted_logits)
-        # Calculate cumulative_probabilities as the prefix-sum array of
-        # probabilities
-        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
-        # We want to remove tokens with cumulative probability higher
-        # than top_p
-        sorted_indices_to_remove = cumulative_probabilities > cls.top_p
-        # Don't ever remove the token with the highest logit, even if
-        # the probability is higher than top_p
-        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
-        # Unsort and remove
-        _, indices_to_remove = jax.lax.sort_key_val(
-            jnp.argsort(-scores),
-            sorted_indices_to_remove,
-        )
-        return jnp.where(indices_to_remove, -jnp.inf, scores)
-
     @classmethod
     def value_is_valid(cls) -> bool:
         return cls.top_p < 1.0
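
[Editor's note: not part of the patch] The TopP.jax_static body removed above
is the clearest statement of the nucleus-sampling recipe the rest of this
patch keeps repeating: sort descending, softmax, prefix-sum, mask everything
past top_p, unsort the mask. A minimal NumPy sketch of the same math, with
made-up toy logits and threshold, illustration only:

    import numpy as np

    logits = np.array([2.0, 0.5, 3.0, 1.0])
    top_p = 0.9

    order = np.argsort(-logits)                 # token order, best first
    probs = np.exp(logits[order] - logits[order].max())
    probs /= probs.sum()                        # softmax of sorted logits
    mask = np.cumsum(probs) > top_p             # True past the nucleus
    mask[0] = False                             # never drop the best token
    unsorted = np.empty_like(mask)
    unsorted[order] = mask                      # scatter back to token order
    print(np.where(unsorted, -np.inf, logits))  # [2. -inf 3. -inf]
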
@@ -242,16 +213,6 @@ class TopK(Warper):
         )
         return np.where(indices_to_remove, -np.inf, scores)
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        sorted_indices_to_remove = jnp.arange(len(scores)) >= cls.top_k
-
-        _, indices_to_remove = jax.lax.sort_key_val(
-            jnp.argsort(-scores),
-            sorted_indices_to_remove,
-        )
-        return jnp.where(indices_to_remove, -jnp.inf, scores)
-
     @classmethod
     def value_is_valid(cls) -> bool:
         return cls.top_k > 0
@@ -339,30 +300,6 @@ class TailFree(Warper):
         )
         return np.where(indices_to_remove, -np.inf, scores)
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        sorted_logits = -jnp.sort(-scores)
-        probabilities = jax.nn.softmax(sorted_logits)
-
-        d2 = jnp.diff(jnp.diff(probabilities))
-        d2 = jnp.abs(d2)
-        d2 = d2 / d2.sum(axis=-1, keepdims=True)
-
-        cumulative_d2 = jnp.cumsum(d2, axis=-1)
-        sorted_indices_to_remove = cumulative_d2 > cls.tfs
-        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
-        sorted_indices_to_remove = jnp.pad(
-            sorted_indices_to_remove,
-            (0, 2),
-            constant_values=True,
-        )
-
-        _, indices_to_remove = jax.lax.sort_key_val(
-            jnp.argsort(-scores),
-            sorted_indices_to_remove,
-        )
-        return jnp.where(indices_to_remove, -jnp.inf, scores)
-
     @classmethod
     def value_is_valid(cls) -> bool:
         return cls.tfs < 1.0
@@ -443,25 +380,6 @@ class Typical(Warper):
         )
         return np.where(indices_to_remove, -jnp.inf, scores)
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        probs = jax.nn.softmax(scores)
-        log_probs = jnp.log(probs)
-
-        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
-        entropy_deviation = jnp.abs(neg_entropy - log_probs)
-
-        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
-        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= cls.typical
-        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
-        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
-
-        _, indices_to_remove = jax.lax.sort_key_val(
-            jnp.argsort(entropy_deviation),
-            sorted_indices_to_remove,
-        )
-        return jnp.where(indices_to_remove, -jnp.inf, scores)
-
     @classmethod
     def value_is_valid(cls) -> bool:
         return cls.typical < 1.0
@@ -504,14 +422,6 @@ class TopA(Warper):
             probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
         )
 
-    @classmethod
-    def jax_static(cls, scores: jnp.array) -> jnp.array:
-        probabilities = jax.nn.softmax(scores)
-        probs_max = probabilities.max()
-        return jnp.where(
-            probabilities < probs_max * probs_max * cls.top_a, -jnp.inf, scores
-        )
-
    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.top_a > 0.0
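
[Editor's note: not part of the patch] Every jax_static removed above ends
with the same unsort idiom, so it is worth spelling out once:
jax.lax.sort_key_val sorts the keys ascending and carries the values along,
so sorting argsort(-scores) back into [0, 1, 2, ...] scatters a mask that was
computed in descending-score order back onto each token's original position.
A standalone sketch with toy values, illustration only:

    import jax
    import jax.numpy as jnp

    scores = jnp.array([2.0, 0.5, 3.0, 1.0])
    # Mask in descending-score order: keep the two largest logits.
    sorted_mask = jnp.array([False, False, True, True])

    _, mask = jax.lax.sort_key_val(jnp.argsort(-scores), sorted_mask)
    print(mask)                               # [False  True False  True]
    print(jnp.where(mask, -jnp.inf, scores))  # [ 2. -inf  3. -inf]
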
@@ -557,75 +467,6 @@ class RepetitionPenalty(Warper):
 
         return scores
 
-    @classmethod
-    def jax_static(
-        cls,
-        logits: jnp.array,
-        tokens: jnp.array,
-        generated_index,
-    ) -> jnp.array:
-        """
-        This gets called to apply repetition penalty to the 1D array logits
-        using the provided 1D array of tokens to penalize
-        """
-        rpslope = jnp.int32(cls.rep_pen_slope)
-        rprange = jnp.int32(cls.rep_pen_range)
-        repetition_penalty = cls.rep_pen
-
-        clipped_rprange = jax.lax.cond(
-            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
-        )
-
-        penalty_arange = jnp.roll(
-            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
-            generated_index,
-            axis=-1,
-        )
-        # Make a new array with the same length as the tokens array but with
-        # each element replaced by the value at the corresponding index in the
-        # logits array; e.g.
-        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
-        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
-        penalty_logits = jnp.take(logits, tokens)
-        # Repetition penalty slope
-        def apply_slope(carry):
-            repetition_penalty, rprange = carry
-            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
-            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
-            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
-            return _penalty
-
-        repetition_penalty = jax.lax.cond(
-            (rpslope != 0.0)
-            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
-            apply_slope,
-            lambda carry: jnp.full(tokens.shape, carry[0]),
-            (repetition_penalty, rprange),
-        )
-        # Divide positive values by repetition_penalty and multiply negative
-        # values by repetition_penalty (the academic publication that described
-        # this technique actually just only divided, but that would cause tokens
-        # with negative logits to become more likely, which is obviously wrong)
-        if cls.use_alt_rep_pen:
-            penalty_logits = jnp.where(
-                penalty_arange >= 0,
-                penalty_logits - jnp.log(repetition_penalty),
-                penalty_logits,
-            )
-        else:
-            penalty_logits = jnp.where(
-                penalty_arange >= 0,
-                jnp.where(
-                    penalty_logits > 0,
-                    penalty_logits / repetition_penalty,
-                    penalty_logits * repetition_penalty,
-                ),
-                penalty_logits,
-            )
-        # Finally, put those penalized logit values back into their original
-        # positions in the logits array
-        return logits.at[tokens].set(penalty_logits)
-
     @classmethod
     def jax_dynamic(
         cls,
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 31e067eb..d2021882 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -226,21 +226,185 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
     # probability distribution)
     return jax.random.categorical(key, logits, -1).astype(np.uint32)
 
-
-def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
+def kobold_sample_static(
+    key,
+    logits,
+    rpargs,
+    sampler_order: Optional[np.ndarray] = None,
+    top_p=0.9,
+    temp=0.5,
+    top_k=0,
+    tfs=1.0,
+    typical=1.0,
+    top_a=0.0,
+):
     '''
     This gets called by generate_loop_fn to apply a series of 6 filters
     to the logits (top-k, then top-a, then top-p, then TFS, then typical,
     then temperature) before picking one token using the modified logits
     '''
+
+    # Lame to have these here instead of modeling/warpers.py but JAX JIT stuff >:(
+    # For documentation see modeling/warpers.py
+    def sample_top_k(scores: jnp.array) -> jnp.array:
+        sorted_indices_to_remove = jnp.arange(len(scores)) >= top_k
+
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+
+
+    def sample_top_a(scores: jnp.array) -> jnp.array:
+        probabilities = jax.nn.softmax(scores)
+        probs_max = probabilities.max()
+        return jnp.where(
+            probabilities < probs_max * probs_max * top_a, -jnp.inf, scores
+        )
+
+
+    def sample_top_p(scores: jnp.array) -> jnp.array:
+        sorted_logits = -jnp.sort(-scores)
+        probabilities = jax.nn.softmax(sorted_logits)
+
+        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
+        sorted_indices_to_remove = cumulative_probabilities > top_p
+
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+
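+    # (Editor's note, illustration only: these helpers keep every intermediate
+    # the same shape as `scores`, which is what the "picky" static/JIT rules in
+    # modeling/warpers.py are about. Under jax.jit, a data-dependent boolean
+    # filter like `scores[scores > 0]` fails to trace (JAX raises
+    # NonConcreteBooleanIndexError) because the result shape depends on runtime
+    # values, while the jnp.where(mask, -jnp.inf, scores) pattern used here has
+    # a static shape and compiles fine.)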
+
+    def sample_tail_free(scores: jnp.array) -> jnp.array:
+        sorted_logits = -jnp.sort(-scores)
+        probabilities = jax.nn.softmax(sorted_logits)
+
+        d2 = jnp.diff(jnp.diff(probabilities))
+        d2 = jnp.abs(d2)
+        d2 = d2 / d2.sum(axis=-1, keepdims=True)
+
+        cumulative_d2 = jnp.cumsum(d2, axis=-1)
+        sorted_indices_to_remove = cumulative_d2 > tfs
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        sorted_indices_to_remove = jnp.pad(
+            sorted_indices_to_remove,
+            (0, 2),
+            constant_values=True,
+        )
+
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-scores),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+
+
+    def sample_typical(scores: jnp.array) -> jnp.array:
+        probs = jax.nn.softmax(scores)
+        log_probs = jnp.log(probs)
+
+        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
+        entropy_deviation = jnp.abs(neg_entropy - log_probs)
+
+        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
+        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
+        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(entropy_deviation),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, scores)
+
+
+    def sample_temperature(scores: jnp.array) -> jnp.array:
+        return scores / temp
+
+
+    def sample_repetition_penalty(
+        logits: jnp.array,
+        tokens: jnp.array,
+        repetition_penalty,
+        generated_index,
+        rpslope,
+        rprange
+    ) -> jnp.array:
+        """
+        This gets called to apply repetition penalty to the 1D array logits
+        using the provided 1D array of tokens to penalize
+        """
+        rpslope = jnp.int32(rpslope)
+        rprange = jnp.int32(rprange)
+
+        clipped_rprange = jax.lax.cond(
+            rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange
+        )
+
+        penalty_arange = jnp.roll(
+            jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]),
+            generated_index,
+            axis=-1,
+        )
+        # Make a new array with the same length as the tokens array but with
+        # each element replaced by the value at the corresponding index in the
+        # logits array; e.g.
+        # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
+        # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
+        penalty_logits = jnp.take(logits, tokens)
+
+        # Repetition penalty slope
+        def apply_slope(carry):
+            repetition_penalty, rprange = carry
+            _penalty = (penalty_arange / (rprange - 1)) * 2 - 1
+            _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
+            _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
+            return _penalty
+
+        repetition_penalty = jax.lax.cond(
+            (rpslope != 0.0)
+            & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
+            apply_slope,
+            lambda carry: jnp.full(tokens.shape, carry[0]),
+            (repetition_penalty, rprange),
+        )
+        # Divide positive values by repetition_penalty and multiply negative
+        # values by repetition_penalty (the academic publication that described
+        # this technique actually just only divided, but that would cause tokens
+        # with negative logits to become more likely, which is obviously wrong)
+        if koboldai_vars.use_alt_rep_pen:
+            penalty_logits = jnp.where(
+                penalty_arange >= 0,
+                penalty_logits - jnp.log(repetition_penalty),
+                penalty_logits,
+            )
+        else:
+            penalty_logits = jnp.where(
+                penalty_arange >= 0,
+                jnp.where(
+                    penalty_logits > 0,
+                    penalty_logits / repetition_penalty,
+                    penalty_logits * repetition_penalty,
+                ),
+                penalty_logits,
+            )
+        # Finally, put those penalized logit values back into their original
+        # positions in the logits array
+        return logits.at[tokens].set(penalty_logits)
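+
+    # (Editor's note, illustration only: penalty_arange above is >= 0 exactly
+    # for the `rprange` most recent positions. E.g. with tokens.shape[-1] == 8,
+    # rprange == 4 and generated_index == 7 it is [-3, -2, -1, 0, 1, 2, 3, -4],
+    # so only slots 3..6, the four tokens just before the one being generated,
+    # are penalized, with the slope scaling the penalty from the oldest (0) to
+    # the newest (rprange - 1) position in that window.)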
+
+
     for k in sampler_order:
-        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), warpers.TopK.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), warpers.TopA.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), warpers.TopP.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), warpers.TailFree.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), warpers.Typical.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), warpers.Temperature.jax_static, lambda x: x, logits)
-        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: warpers.RepetitionPenalty.jax_static(*x), lambda x: x[0], (logits, *rpargs))
+        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), sample_top_k, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), sample_top_a, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), sample_top_p, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), sample_tail_free, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), sample_typical, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), sample_temperature, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: sample_repetition_penalty(*x), lambda x: x[0], (logits, *rpargs))
     return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
 
 pad_token_id = 50256
@@ -357,11 +521,10 @@ class PenalizingCausalTransformer(CausalTransformer):
                 logits,
                 (
                     generated,
-                    # repetition_penalty,
+                    repetition_penalty,
                     generated_index,
-                    # gen_length,
-                    # rpslope,
-                    # rprange,
+                    rpslope,
+                    rprange,
                 ),
                 **sampler_options,
             )
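
[Editor's note: not part of the patch] The rebuilt loop dispatches on
sampler_order one step at a time: for each k, at most one of the seven
jax.lax.cond pairs fires (when that sampler is enabled), so the user-chosen
order is respected without any Python control flow that would depend on
traced values. A runnable toy of the same dispatch pattern, with made-up
warp_a / warp_b stand-ins for the real samplers:

    import jax
    import jax.numpy as jnp

    def warp_a(x):
        return x * 2.0

    def warp_b(x):
        return x + 1.0

    def chain(logits, order):
        # Each pass through the loop conditionally applies exactly one warper,
        # mirroring the k == 0 ... k == 6 chain in kobold_sample_static.
        for k in order:
            logits = jax.lax.cond(k == 0, warp_a, lambda x: x, logits)
            logits = jax.lax.cond(k == 1, warp_b, lambda x: x, logits)
        return logits

    print(chain(jnp.array([1.0, 2.0]), jnp.array([0, 1])))  # [3. 5.]
    print(chain(jnp.array([1.0, 2.0]), jnp.array([1, 0])))  # [4. 6.]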