From 9b1138bafa5b0e68750f783579e14d8bfb624bd7 Mon Sep 17 00:00:00 2001
From: ebolam
Date: Tue, 10 Jan 2023 08:45:55 -0500
Subject: [PATCH] Added in alternative rep pen calculation (log instead of
 linear application) as an option.

---
 aiserver.py          |  2 +-
 gensettings.py       | 16 ++++++++++++++
 koboldai_settings.py |  1 +
 tpu_mtj_backend.py   | 51 ++++++++++++++++++++++++++++----------------
 warpers.py           |  5 ++++-
 5 files changed, 55 insertions(+), 20 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index b2dc22dc..7c32e97b 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -2180,7 +2180,7 @@ def patch_transformers():
                 return old_call(self, *args, **kwargs)
             return args[1]
         cls.__call__ = new_call
-    dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
+    dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"), ("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"), cond=lambda x: x[0] != 1.0)
     dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
     dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
     dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
diff --git a/gensettings.py b/gensettings.py
index 214d2f45..a57e1de5 100644
--- a/gensettings.py
+++ b/gensettings.py
@@ -160,6 +160,22 @@ gensettingstf = [
 	"name": "rep_pen_slope",
 	"ui_level": 1
 	},
+	{
+	"uitype": "toggle",
+	"unit": "bool",
+	"label": "Alt Rep Pen",
+	"id": "use_alt_rep_pen",
+	"min": 0,
+	"max": 1,
+	"step": 1,
+	"default": 0,
+	"tooltip": "Applies repetition penalty as a logarithmic modifier rather than a linear modifier.",
+	"menu_path": "Settings",
+	"sub_path": "Repetition",
+	"classname": "model",
+	"name": "use_alt_rep_pen",
+	"ui_level": 2
+	},
 	{
 	"uitype": "slider",
 	"unit": "int",
diff --git a/koboldai_settings.py b/koboldai_settings.py
index 9feee060..5031bcea 100644
--- a/koboldai_settings.py
+++ b/koboldai_settings.py
@@ -720,6 +720,7 @@ class model_settings(settings):
         self.horde_wait_time = 0
         self.horde_queue_position = 0
         self.horde_queue_size = 0
+        self.use_alt_rep_pen = False
 
 
 
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 71fa1f87..02754d95 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -201,15 +201,23 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
     # values by repetition_penalty (the academic publication that described
     # this technique actually just only divided, but that would cause tokens
     # with negative logits to become more likely, which is obviously wrong)
-    penalty_logits = np.where(
-        penalty_arange >= 0,
-        np.where(
-            penalty_logits > 0,
-            penalty_logits/repetition_penalty,
-            penalty_logits*repetition_penalty,
-        ),
-        penalty_logits,
-    )
+    if koboldai_vars.use_alt_rep_pen:
+        penalty_logits = np.where(
+            penalty_arange >= 0,
+            penalty_logits - np.log(repetition_penalty),
+            penalty_logits,
+        )
+
+    else:
+        penalty_logits = np.where(
+            penalty_arange >= 0,
+            np.where(
+                penalty_logits > 0,
+                penalty_logits/repetition_penalty,
+                penalty_logits*repetition_penalty,
+            ),
+            penalty_logits,
+        )
     # Finally, put those penalized logit values back into their original
     # positions in the logits array
     logits[tokens] = penalty_logits
@@ -389,15 +397,22 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
     # values by repetition_penalty (the academic publication that described
     # this technique actually just only divided, but that would cause tokens
     # with negative logits to become more likely, which is obviously wrong)
-    penalty_logits = jnp.where(
-        penalty_arange >= 0,
-        jnp.where(
-            penalty_logits > 0,
-            penalty_logits/repetition_penalty,
-            penalty_logits*repetition_penalty,
-        ),
-        penalty_logits,
-    )
+    if koboldai_vars.use_alt_rep_pen:
+        penalty_logits = jnp.where(
+            penalty_arange >= 0,
+            penalty_logits - jnp.log(repetition_penalty),
+            penalty_logits,
+        )
+    else:
+        penalty_logits = jnp.where(
+            penalty_arange >= 0,
+            jnp.where(
+                penalty_logits > 0,
+                penalty_logits/repetition_penalty,
+                penalty_logits*repetition_penalty,
+            ),
+            penalty_logits,
+        )
     # Finally, put those penalized logit values back into their original
     # positions in the logits array
     return logits.at[tokens].set(penalty_logits)
diff --git a/warpers.py b/warpers.py
index 488a901e..ef46f318 100644
--- a/warpers.py
+++ b/warpers.py
@@ -51,7 +51,10 @@
         self.penalty = _penalty[..., -clipped_penalty_range:]
 
         score = torch.gather(scores, 1, input_ids)
-        score = torch.where(score <= 0, score * self.penalty, score / self.penalty)
+        if self.use_alt_rep_pen:
+            score = score - torch.log(self.penalty)
+        else:
+            score = torch.where(score <= 0, score * self.penalty, score / self.penalty)
         scores.scatter_(1, input_ids, score)
 
         return scores
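
Note for reviewers: the sketch below is illustrative only and is not part of the patch; the function names apply_rep_pen_linear and apply_rep_pen_log are hypothetical stand-ins for the code paths the patch toggles between (the patch itself edits apply_repetition_penalty_dynamic/_static and AdvancedRepetitionPenaltyLogitsProcessor directly). It shows, in plain NumPy, how the existing linear penalty and the new logarithmic penalty differ when applied to a handful of logits.

    # Minimal comparison of the two repetition-penalty modes (assumed
    # standalone example, not taken from the patched files).
    import numpy as np

    def apply_rep_pen_linear(logits, seen_tokens, penalty):
        # Existing behaviour: divide positive logits by the penalty and
        # multiply negative ones, so seen tokens always become less likely.
        picked = logits[seen_tokens]
        logits[seen_tokens] = np.where(picked > 0, picked / penalty, picked * penalty)
        return logits

    def apply_rep_pen_log(logits, seen_tokens, penalty):
        # Alternative behaviour added by this patch: subtract log(penalty),
        # shifting every seen token's logit down by the same fixed amount
        # regardless of its sign or magnitude.
        logits[seen_tokens] = logits[seen_tokens] - np.log(penalty)
        return logits

    logits = np.array([4.0, 2.0, -1.0, 0.5])
    seen = np.array([0, 2])
    print(apply_rep_pen_linear(logits.copy(), seen, 1.2))  # [3.3333  2.  -1.2   0.5]
    print(apply_rep_pen_log(logits.copy(), seen, 1.2))     # [3.8177  2.  -1.1823 0.5]

The practical difference is that the linear form penalizes large-magnitude logits proportionally harder, while the logarithmic form applies a uniform shift, which is why it is exposed as an opt-in toggle (use_alt_rep_pen) rather than a replacement.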