From 2db1f2f7bb4dea89fb69aff93f4f1207f2974ace Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 25 Jan 2022 15:05:21 -0500 Subject: [PATCH 1/7] AvrilAI-style repetition penalty test --- aiserver.py | 5 ++--- tpu_mtj_backend.py | 45 ++++++++++++++++++++------------------------- warpers.py | 2 +- 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/aiserver.py b/aiserver.py index 64470d1e..5be7a17a 100644 --- a/aiserver.py +++ b/aiserver.py @@ -722,8 +722,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0) - RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__ - RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__ class LuaLogitsProcessor(LogitsProcessor): @@ -767,6 +765,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TemperatureLogitsWarper(temperature=0.5)) + warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor()) return warper_list def new_sample(self, *args, **kwargs): @@ -2771,7 +2770,7 @@ def _generate(txt, minimum, maximum, found_entries): do_sample=True, min_length=minimum, max_length=int(2e9), - repetition_penalty=1.1, + repetition_penalty=1.0, bad_words_ids=vars.badwordsids, use_cache=True, num_return_sequences=numseqs diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 653f8cf1..e7632eba 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -149,7 +149,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat logits[tokens] = penalty_logits return logits -def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): +def kobold_sample_dynamic(key, logits, rpargs, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): ''' This gets called by generate_loop_fn to apply a series of 4 filters to the logits (top-k, then top-p, then TFS, then temperature) before @@ -245,6 +245,7 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) + logits = apply_repetition_penalty_dynamic(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(np.uint32) def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): @@ -292,7 +293,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate # positions in the logits array return logits.at[tokens].set(penalty_logits) -def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): +def kobold_sample_static(key, logits, rpargs, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): ''' This gets called by generate_loop_fn to apply a series of 4 filters to the logits (top-k, then top-p, then TFS, then temperature) before @@ -387,6 +388,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) + logits = apply_repetition_penalty_static(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(jnp.uint32) pad_token_id = 50256 @@ -400,17 +402,6 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ # Get the pseudo-random number generator key that will # be used by kobold_sample_dynamic to randomly pick a token sample_key, new_key = jax.random.split(sample_key, num=2) - # Apply repetition penalty to all tokens that are - # currently inside the "generated" array - logits = apply_repetition_penalty_dynamic( - logits, - generated, - repetition_penalty, - generated_index, - gen_length, - rpslope, - rprange, - ) # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively # makes their probabilities of being chosen zero @@ -422,6 +413,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ next_token = kobold_sample_dynamic( sample_key, logits, + ( + generated, + repetition_penalty, + generated_index, + gen_length, + rpslope, + rprange, + ) **sampler_options, ) # Remember what token was picked @@ -493,18 +492,6 @@ class PenalizingCausalTransformer(CausalTransformer): assert logits.shape == (1, config["n_vocab"]) # Flatten it into a 1D array to make it easier to use logits = logits[0] - # Apply repetition penalty to all tokens that are - # currently inside the "generated" array - if repetition_penalty is not None: - logits = apply_repetition_penalty_static( - logits, - generated, - repetition_penalty, - generated_index, - gen_length, - rpslope, - rprange, - ) # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively # makes their probabilities of being chosen zero @@ -516,6 +503,14 @@ class PenalizingCausalTransformer(CausalTransformer): next_token = kobold_sample_static( sample_key, logits, + ( + generated, + repetition_penalty, + generated_index, + gen_length, + rpslope, + rprange, + ), **sampler_options, ) # Remember what token was picked diff --git a/warpers.py b/warpers.py index 07670f6d..122bc1cd 100644 --- a/warpers.py +++ b/warpers.py @@ -31,7 +31,7 @@ import torch from transformers import LogitsWarper, LogitsProcessor -class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor): +class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper): def __init__(self, *args, **kwargs): pass From 9eecb61feaa14a35aedb5401b3b0d4b84052b58e Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 14:52:45 -0400 Subject: [PATCH 2/7] Remove unused import from warpers.py --- warpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/warpers.py b/warpers.py index 2eac074e..488a901e 100644 --- a/warpers.py +++ b/warpers.py @@ -28,7 +28,7 @@ SOFTWARE. ''' import torch -from transformers import LogitsWarper, LogitsProcessor +from transformers import LogitsWarper class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper): From 6ffaf43548b7a73b969b91eb5e16e0c6c86f6483 Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 15:10:21 -0400 Subject: [PATCH 3/7] Repetition penalty is now sampler #6 in the sampler order --- aiserver.py | 20 +++++++++++++++----- tpu_mtj_backend.py | 7 +++++-- utils.py | 2 +- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/aiserver.py b/aiserver.py index 310067ad..6539fcb8 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1806,7 +1806,10 @@ def patch_transformers(): self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor()) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs): - for k in vars.sampler_order: + sampler_order = vars.sampler_order[:] + if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present + sampler_order = [6] + sampler_order + for k in sampler_order: scores = self.__warper_list[k](input_ids, scores, *args, **kwargs) return scores @@ -1939,7 +1942,7 @@ def reset_model_settings(): vars.badwordsids = [] vars.fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format vars.modeldim = -1 # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B) - vars.sampler_order = [0, 1, 2, 3, 4, 5] + vars.sampler_order = [6, 0, 1, 2, 3, 4, 5] vars.newlinemode = "n" vars.revision = None @@ -2550,8 +2553,11 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal vars.compiling = False def tpumtjgenerate_settings_callback() -> dict: + sampler_order = vars.sampler_order[:] + if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present + sampler_order = [6] + sampler_order return { - "sampler_order": vars.sampler_order, + "sampler_order": sampler_order, "top_p": float(vars.top_p), "temp": float(vars.temp), "top_k": int(vars.top_k), @@ -3658,12 +3664,16 @@ def get_message(msg): sendUSStatItems() elif(msg['cmd'] == 'samplers'): sampler_order = msg["data"] + sampler_order_min_length = 6 + sampler_order_max_length = 7 if(not isinstance(sampler_order, list)): raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}") - if(len(sampler_order) != len(vars.sampler_order)): - raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}") + if(not (sampler_order_min_length <= len(sampler_order) <= sampler_order_max_length)): + raise ValueError(f"Sampler order must be a list of length greater than or equal to {sampler_order_min_length} and less than or equal to {sampler_order_max_length}, but got a list of length {len(sampler_order)}") if(not all(isinstance(e, int) for e in sampler_order)): raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element") + if(min(sampler_order) != 0 or max(sampler_order) != len(sampler_order) - 1 or len(set(sampler_order)) != len(sampler_order)): + raise ValueError(f"Sampler order list of length {len(sampler_order)} must be a permutation of the first {len(sampler_order)} nonnegative integers") vars.sampler_order = sampler_order settingschanged() elif(msg['cmd'] == 'list_model'): diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 1837fae6..19296e0a 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -312,10 +312,10 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra if k == 3 and tfs < 1.0: logits = tail_free_filter(logits) if k == 4 and typical < 1.0: logits = typical_filter(logits) if k == 5 and temp != 1.0: logits = temp_filter(logits) + if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) - logits = apply_repetition_penalty_dynamic(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(np.uint32) def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): @@ -498,10 +498,10 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), apply_repetition_penalty_static, lambda x, *_: x, logits, *rpargs) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) - logits = apply_repetition_penalty_static(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(jnp.uint32) pad_token_id = 50256 @@ -858,6 +858,9 @@ def infer_static( maps.thread_resources.env = thread_resources_env if sampler_order is None: sampler_order = utils.default_sampler_order.copy() + sampler_order = sampler_order[:] + if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present + sampler_order = [6] + sampler_order sampler_order = np.uint32(sampler_order) total_batch = 1 tokens = context diff --git a/utils.py b/utils.py index 7fd82072..76c04ea2 100644 --- a/utils.py +++ b/utils.py @@ -33,7 +33,7 @@ layers_module_names: Optional[List[str]] = None module_names: Optional[List[str]] = None named_buffers: Optional[List[tuple]] = None -default_sampler_order = [0, 1, 2, 3, 4, 5] +default_sampler_order = [6, 0, 1, 2, 3, 4, 5] #==================================================================# # Decorator to prevent a function's actions from being run until From aee4beb27a58c9f0dfb024e13f20da39fc7b9a48 Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 15:26:15 -0400 Subject: [PATCH 4/7] Fix the Show Field Budget toggle --- static/application.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/static/application.js b/static/application.js index 06c426b4..25564cdf 100644 --- a/static/application.js +++ b/static/application.js @@ -256,7 +256,7 @@ function addSetting(ob) { } }); - if (!$("#input-token-usage")[0].checked) { + if (!$("#setshowbudget")[0].checked) { for (const el of document.getElementsByClassName("input-token-usage")) { el.classList.add("hidden"); } From cbfe456409a82872396e12941f920e30ff708720 Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 15:30:07 -0400 Subject: [PATCH 5/7] Repetition penalty is now added to sampler list when loading from settings files --- aiserver.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/aiserver.py b/aiserver.py index 6539fcb8..2a26bc5e 100644 --- a/aiserver.py +++ b/aiserver.py @@ -963,7 +963,10 @@ def loadmodelsettings(): if("nobreakmodel" in js): vars.nobreakmodel = js["nobreakmodel"] if("sampler_order" in js): - vars.sampler_order = js["sampler_order"] + sampler_order = vars.sampler_order + if(len(sampler_order) < 7): + sampler_order = [6] + sampler_order + vars.sampler_order = sampler_order if("temp" in js): vars.temp = js["temp"] if("top_p" in js): @@ -1094,7 +1097,10 @@ def processsettings(js): if("andepth" in js): vars.andepth = js["andepth"] if("sampler_order" in js): - vars.sampler_order = js["sampler_order"] + sampler_order = vars.sampler_order + if(len(sampler_order) < 7): + sampler_order = [6] + sampler_order + vars.sampler_order = sampler_order if("temp" in js): vars.temp = js["temp"] if("top_p" in js): From ff9058896ebc7dd71781b022fe994a37de652aa4 Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 15:42:23 -0400 Subject: [PATCH 6/7] Add Repetition Penalty to Samplers menu --- static/application.js | 3 ++- static/custom.css | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/static/application.js b/static/application.js index 25564cdf..9107e161 100644 --- a/static/application.js +++ b/static/application.js @@ -1306,12 +1306,13 @@ function buildSamplerList(samplers) { "Tail-free Sampling", "Typical Sampling", "Temperature", + "Repetition Penalty", ] for(i=0; i\
\
\ -
"+samplers_lookup_table[samplers[i]]+"
\ +
"+(samplers[i] < samplers_lookup_table.length ? samplers_lookup_table[samplers[i]] : "Unknown sampler #" + samplers[i])+"
\
\
\ "); diff --git a/static/custom.css b/static/custom.css index af238dc7..d4bfe872 100644 --- a/static/custom.css +++ b/static/custom.css @@ -473,7 +473,7 @@ body.connected #popupfooter, #popupfooter.always-available { } #samplerslist { - height: 300px; + height: 310px; overflow-y: scroll; overflow-wrap: anywhere; } From 938e1eddf32a67fc1af127afd6eb7cfbdabde2e8 Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 18:13:46 -0400 Subject: [PATCH 7/7] Fix `jax.lax.cond` call --- tpu_mtj_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 19296e0a..effb3de0 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -498,7 +498,7 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits) - logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), apply_repetition_penalty_static, lambda x, *_: x, logits, *rpargs) + logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs)) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution)