diff --git a/model.py b/model.py index e7cd60c7..ba12516a 100644 --- a/model.py +++ b/model.py @@ -2203,8 +2203,6 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel): # Get the model_type from the config or assume a model type if it isn't present try: - print("LMP:", self.get_local_model_path()) - print("M:", utils.koboldai_vars.model) model_config = AutoConfig.from_pretrained( self.get_local_model_path() or utils.koboldai_vars.model, revision=utils.koboldai_vars.revision, diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 855413a2..9a56ffd5 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -54,6 +54,8 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor from mesh_transformer.util import to_bf16 import time +import warpers +from warpers import Warper socketio = None @@ -213,6 +215,7 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) before picking one token using the modified logits ''' + """ # Top-k (keep only the k tokens with the highest logits and remove # the rest, by setting their logits to negative infinity) def top_k_filter(logits): @@ -344,6 +347,18 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra if k == 4 and typical < 1.0: logits = typical_filter(logits) if k == 5 and temp != 1.0: logits = temp_filter(logits) if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs) + """ + for sid in sampler_order: + warper = Warper.from_id(sid) + if not warper.value_is_valid(): + continue + + if warper == warpers.RepetitionPenalty: + # repetition penalty needs the generation context carried in rpargs + logits = warper.jax_dynamic(logits, *rpargs) + else: + # other warpers transform the logits alone + logits = warper.jax_dynamic(logits) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) @@
-356,6 +371,7 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) before picking one token using the modified logits ''' + """ # Top-k (keep only the k tokens with the highest logits and remove # the rest, by setting their logits to negative infinity) def top_k_filter(logits): @@ -486,6 +502,18 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs)) + """ + for sid in sampler_order: + warper = Warper.from_id(sid) + if not warper.value_is_valid(): + continue + + if warper == warpers.RepetitionPenalty: + # repetition penalty needs the generation context carried in rpargs + logits = warper.jax_static(logits, *rpargs) + else: + # other warpers transform the logits alone + logits = warper.jax_static(logits) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) @@ -515,11 +543,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ logits, ( generated, - repetition_penalty, + # repetition_penalty, generated_index, gen_length, - rpslope, - rprange, + # rpslope, + # rprange, ), **sampler_options, ) @@ -605,11 +633,11 @@ class PenalizingCausalTransformer(CausalTransformer): logits, ( generated, - repetition_penalty, + # repetition_penalty, generated_index, - gen_length, - rpslope, - rprange, + gen_length, + # rpslope, + # rprange, ), **sampler_options, ) diff --git a/warpers.py b/warpers.py index f39e0b69..9c78ab2a 100644 --- a/warpers.py +++ b/warpers.py @@ -92,6 +92,10 @@ class Temperature(Warper): def jax(cls, scores: jnp.array)
-> jnp.array: return scores / cls.temperature + @classmethod + def value_is_valid(cls) -> bool: + return cls.temperature != 1.0 + class TopP(Warper): """ @@ -140,6 +144,10 @@ class TopP(Warper): ) return jnp.where(indices_to_remove, -jnp.inf, scores) + @classmethod + def value_is_valid(cls) -> bool: + return cls.top_p < 1.0 + class TopK(Warper): """ @@ -173,6 +181,10 @@ class TopK(Warper): ) return np.where(indices_to_remove, -np.inf, scores) + @classmethod + def value_is_valid(cls) -> bool: + return cls.top_k > 0 + class TailFree(Warper): """ @@ -256,6 +268,10 @@ class TailFree(Warper): ) return np.where(indices_to_remove, -np.inf, scores) + @classmethod + def value_is_valid(cls) -> bool: + return cls.tfs < 1.0 + class Typical(Warper): """Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf""" @@ -332,6 +348,10 @@ class Typical(Warper): ) return np.where(indices_to_remove, -jnp.inf, scores) + @classmethod + def value_is_valid(cls) -> bool: + return cls.typical < 1.0 + class TopA(Warper): """ @@ -370,6 +390,10 @@ class TopA(Warper): probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores ) + @classmethod + def value_is_valid(cls) -> bool: + return cls.top_a > 0.0 + class RepetitionPenalty(Warper): rep_pen: float = 1.0 @@ -543,3 +567,7 @@ class RepetitionPenalty(Warper): # positions in the logits array scores[tokens] = penalty_logits return scores + + @classmethod + def value_is_valid(cls) -> bool: + return cls.rep_pen != 1.0 \ No newline at end of file