diff --git a/aiserver.py b/aiserver.py index 1a53c344..310067ad 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1727,8 +1727,6 @@ def patch_transformers(): dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0) dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0) - RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__ - RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__ class LuaLogitsProcessor(LogitsProcessor): @@ -1805,6 +1803,7 @@ def patch_transformers(): self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5)) + self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor()) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs): for k in vars.sampler_order: @@ -4617,7 +4616,7 @@ def _generate(txt, minimum, maximum, found_entries): gen_in, do_sample=True, max_length=int(2e9), - repetition_penalty=1.1, + repetition_penalty=1.0, bad_words_ids=vars.badwordsids, use_cache=True, num_return_sequences=numseqs diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 7b0f6807..1837fae6 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -176,7 +176,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat logits[tokens] = penalty_logits return logits -def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): +def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' This gets called by generate_loop_fn to apply a series of 6 filters to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) @@ -315,6 +315,7 @@ def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = Non # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) + logits = apply_repetition_penalty_dynamic(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(np.uint32) def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): @@ -362,7 +363,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate # positions in the logits array return logits.at[tokens].set(penalty_logits) -def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): +def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' This gets called by generate_loop_fn to apply a series of 6 filters to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) @@ -500,6 +501,7 @@ def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) + logits = apply_repetition_penalty_static(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(jnp.uint32) pad_token_id = 50256 @@ -513,17 +515,6 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ # Get the pseudo-random number generator key that will # be used by kobold_sample_dynamic to randomly pick a token sample_key, new_key = jax.random.split(sample_key, num=2) - # Apply repetition penalty to all tokens that are - # currently inside the "generated" array - logits = apply_repetition_penalty_dynamic( - logits, - generated, - repetition_penalty, - generated_index, - gen_length, - rpslope, - rprange, - ) # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively # makes their probabilities of being chosen zero @@ -535,6 +526,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ next_token = kobold_sample_dynamic( sample_key, logits, + ( + generated, + repetition_penalty, + generated_index, + gen_length, + rpslope, + rprange, + ) **sampler_options, ) # Remember what token was picked @@ -606,18 +605,6 @@ class PenalizingCausalTransformer(CausalTransformer): assert logits.shape == (1, config["n_vocab"]) # Flatten it into a 1D array to make it easier to use logits = logits[0] - # Apply repetition penalty to all tokens that are - # currently inside the "generated" array - if repetition_penalty is not None: - logits = apply_repetition_penalty_static( - logits, - generated, - repetition_penalty, - generated_index, - gen_length, - rpslope, - rprange, - ) # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively # makes their probabilities of being chosen zero @@ -629,6 +616,14 @@ class PenalizingCausalTransformer(CausalTransformer): next_token = kobold_sample_static( sample_key, logits, + ( + generated, + repetition_penalty, + generated_index, + gen_length, + rpslope, + rprange, + ), **sampler_options, ) # Remember what token was picked diff --git a/warpers.py b/warpers.py index fb683f50..2eac074e 100644 --- a/warpers.py +++ b/warpers.py @@ -31,7 +31,7 @@ import torch from transformers import LogitsWarper, LogitsProcessor -class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor): +class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper): def __init__(self, *args, **kwargs): pass