From fdb2a7fa4ce87a586aae1590ded4630a1a03e48d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 10 Jun 2022 22:28:20 -0400 Subject: [PATCH 1/4] Top-A sampling --- aiserver.py | 23 ++++++++++++++++++++++- bridge.lua | 2 ++ gensettings.py | 11 +++++++++++ static/application.js | 7 +++++++ templates/index.html | 2 +- tpu_mtj_backend.py | 42 ++++++++++++++++++++++++++++++++++++------ warpers.py | 29 +++++++++++++++++++++++++++++ 7 files changed, 108 insertions(+), 8 deletions(-) diff --git a/aiserver.py b/aiserver.py index fefab9b8..f14bbd77 100644 --- a/aiserver.py +++ b/aiserver.py @@ -212,6 +212,7 @@ class vars: temp = 0.5 # Default generator temperature top_p = 0.9 # Default generator top_p top_k = 0 # Default generator top_k + top_a = 0.0 # Default generator top-a tfs = 1.0 # Default generator tfs (tail-free sampling) typical = 1.0 # Default generator typical sampling threshold numseqs = 1 # Number of sequences to ask the generator to create @@ -577,6 +578,8 @@ def loadmodelsettings(): vars.tfs = js["tfs"] if("typical" in js): vars.typical = js["typical"] + if("top_a" in js): + vars.top_a = js["top_a"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] if("rep_pen_slope" in js): @@ -613,6 +616,7 @@ def savesettings(): js["top_k"] = vars.top_k js["tfs"] = vars.tfs js["typical"] = vars.typical + js["top_a"] = vars.top_a js["rep_pen"] = vars.rep_pen js["rep_pen_slope"] = vars.rep_pen_slope js["rep_pen_range"] = vars.rep_pen_range @@ -693,6 +697,8 @@ def processsettings(js): vars.tfs = js["tfs"] if("typical" in js): vars.typical = js["typical"] + if("top_a" in js): + vars.top_a = js["top_a"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] if("rep_pen_slope" in js): @@ -1379,7 +1385,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go # Patch transformers to use our custom logit warpers from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, 
RepetitionPenaltyLogitsProcessor - from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper + from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper def dynamic_processor_wrap(cls, field_name, var_name, cond=None): old_call = cls.__call__ @@ -1399,6 +1405,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go cls.__call__ = new_call dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0) dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) + dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0) dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0) @@ -1445,6 +1452,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: warper_list = LogitsProcessorList() warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) + warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) @@ -1814,6 +1822,7 @@ else: "top_k": int(vars.top_k), "tfs": float(vars.tfs), "typical": float(vars.typical), + "top_a": float(vars.top_a), "repetition_penalty": float(vars.rep_pen), "rpslope": float(vars.rep_pen_slope), "rprange": int(vars.rep_pen_range), @@ -2176,6 +2185,7 
@@ def lua_has_setting(setting): "settopk", "settfs", "settypical", + "settopa", "setreppen", "setreppenslope", "setreppenrange", @@ -2195,6 +2205,7 @@ def lua_has_setting(setting): "top_k", "tfs", "typical", + "topa", "reppen", "reppenslope", "reppenrange", @@ -2229,6 +2240,7 @@ def lua_get_setting(setting): if(setting in ("settopk", "topk", "top_k")): return vars.top_k if(setting in ("settfs", "tfs")): return vars.tfs if(setting in ("settypical", "typical")): return vars.typical + if(setting in ("settopa", "topa")): return vars.top_a if(setting in ("setreppen", "reppen")): return vars.rep_pen if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range @@ -2264,6 +2276,7 @@ def lua_set_setting(setting, v): if(setting in ("settopk", "topk")): vars.top_k = v if(setting in ("settfs", "tfs")): vars.tfs = v if(setting in ("settypical", "typical")): vars.typical = v + if(setting in ("settopa", "topa")): vars.top_a = v if(setting in ("setreppen", "reppen")): vars.rep_pen = v if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v @@ -2688,6 +2701,11 @@ def get_message(msg): emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True) settingschanged() refresh_settings() + elif(msg['cmd'] == 'settopa'): + vars.top_a = float(msg['data']) + emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True) + settingschanged() + refresh_settings() elif(msg['cmd'] == 'setreppen'): vars.rep_pen = float(msg['data']) emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True) @@ -3748,6 +3766,7 @@ def sendtocolab(txt, min, max): 'top_k': vars.top_k, 'tfs': vars.tfs, 'typical': vars.typical, + 'topa': vars.top_a, 'numseqs': vars.numseqs, 'retfultxt': False } @@ -3885,6 +3904,7 @@ def tpumtjgenerate(txt, minimum, maximum, 
found_entries=None): top_k=vars.top_k, tfs=vars.tfs, typical=vars.typical, + top_a=vars.top_a, numseqs=vars.numseqs, repetition_penalty=vars.rep_pen, rpslope=vars.rep_pen_slope, @@ -4071,6 +4091,7 @@ def refresh_settings(): emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True) emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True) emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True) + emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True) emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True) emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True) emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True) diff --git a/bridge.lua b/bridge.lua index ed0941c6..fc6c8823 100644 --- a/bridge.lua +++ b/bridge.lua @@ -867,6 +867,7 @@ return function(_python, _bridged) ---@field settopk integer ---@field settfs number ---@field settypical number + ---@field settopa number ---@field setreppen number ---@field setreppenslope number ---@field setreppenrange number @@ -884,6 +885,7 @@ return function(_python, _bridged) ---@field top_k integer ---@field tfs number ---@field typical number + ---@field topa number ---@field reppen number ---@field reppenslope number ---@field reppenrange number diff --git a/gensettings.py b/gensettings.py index e8d4e566..b3007c91 100644 --- a/gensettings.py +++ b/gensettings.py @@ -64,6 +64,17 @@ gensettingstf = [ "step": 0.05, "default": 1.0, "tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect." 
+ }, + { + "uitype": "slider", + "unit": "float", + "label": "Top a Sampling", + "id": "settopa", + "min": 0.0, + "max": 1.0, + "step": 0.01, + "default": 0.0, + "tooltip": "Alternative sampling method that reduces the randomness of the AI whenever the probability of one token is much higher than all the others. Higher values have a stronger effect. Set this setting to 0 to disable its effect." }, { "uitype": "slider", diff --git a/static/application.js b/static/application.js index 3b50281c..55487f76 100644 --- a/static/application.js +++ b/static/application.js @@ -2096,6 +2096,10 @@ $(document).ready(function(){ // Send current typical value to input $("#settypicalcur").val(msg.data); $("#settypical").val(parseFloat(msg.data)).trigger("change"); + } else if(msg.cmd == "updatetopa") { + // Send current top a value to input + $("#settopacur").val(msg.data); + $("#settopa").val(parseFloat(msg.data)).trigger("change"); } else if(msg.cmd == "updatereppen") { // Send current rep pen value to input $("#setreppencur").val(msg.data); @@ -2135,6 +2139,9 @@ $(document).ready(function(){ } else if(msg.cmd == "setlabeltypical") { // Update setting label with value from server $("#settypicalcur").val(msg.data); + } else if(msg.cmd == "setlabeltypical") { + // Update setting label with value from server + $("#settopa").val(msg.data); } else if(msg.cmd == "setlabelreppen") { // Update setting label with value from server $("#setreppencur").val(msg.data); diff --git a/templates/index.html b/templates/index.html index a3214ffa..690535f7 100644 --- a/templates/index.html +++ b/templates/index.html @@ -17,7 +17,7 @@ - + diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index fb2dc7ae..f66ad53c 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -70,6 +70,7 @@ def settings_callback() -> dict: "top_k": 0, "tfs": 1.0, "typical": 1.0, + "top_a": 0.0, "repetition_penalty": 1.0, "rpslope": 0.0, "rprange": 0, @@ -158,10 +159,10 @@ def apply_repetition_penalty_dynamic(logits, 
tokens, repetition_penalty, generat logits[tokens] = penalty_logits return logits -def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0): +def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' - This gets called by generate_loop_fn to apply a series of 5 filters - to the logits (top-k, then top-p, then TFS, then typical, then temperature) + This gets called by generate_loop_fn to apply a series of 6 filters + to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) before picking one token using the modified logits ''' # Top-k (keep only the k tokens with the highest logits and remove @@ -182,6 +183,20 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty return np.where(indices_to_remove, -np.inf, logits) if top_k > 0: logits = top_k_filter(logits) + # Top-a (remove all tokens that have softmax probability less than + # a*m^2 where m is the maximum softmax probability) + def top_a_filter(logits): + # Replace every element in the logits array + # with e (Euler's number) to the power of that element, and divide + # each element of the new array by the sum of the elements in the + # new array + probabilities = np.array(jax.nn.softmax(logits), copy=True) + # Find the largest probability + probs_max = probabilities.max() + # Remove tokens + return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits) + if top_a > 0.0: + logits = top_a_filter(logits) # Top-p (after sorting the remaining tokens again in descending order of # logit, remove the ones that have cumulative softmax probability # greater than p) @@ -332,10 +347,10 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate # positions in the logits array return logits.at[tokens].set(penalty_logits) -def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0): +def kobold_sample_static(key, 
logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' - This gets called by generate_loop_fn to apply a series of 5 filters - to the logits (top-k, then top-p, then TFS, then typical, then temperature) + This gets called by generate_loop_fn to apply a series of 6 filters + to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) before picking one token using the modified logits ''' # Top-k (keep only the k tokens with the highest logits and remove @@ -355,6 +370,19 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ ) return jnp.where(indices_to_remove, -jnp.inf, logits) logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits) + # Top-a (remove all tokens that have softmax probability less than + # a*m^2 where m is the maximum softmax probability) + def top_a_filter(logits): + # Replace every element in the logits array + # with e (Euler's number) to the power of that element, and divide + # each element of the new array by the sum of the elements in the + # new array + probabilities = jax.nn.softmax(logits) + # Find the largest probability + probs_max = probabilities.max() + # Remove tokens + return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits) + logits = jax.lax.cond(top_a > 0.0, top_a_filter, lambda x: x, logits) # Top-p (after sorting the remaining tokens again in descending order of # logit, remove the ones that have cumulative softmax probability # greater than p) @@ -806,6 +834,7 @@ def infer_static( top_k=0, tfs=1.0, typical=1.0, + top_a=0.0, repetition_penalty=1.0, rpslope=0.0, rprange=0, @@ -829,6 +858,7 @@ def infer_static( "top_p": top_p * np.ones(total_batch), "tfs": tfs * np.ones(total_batch), "typical": typical * np.ones(total_batch), + "top_a": top_a * np.ones(total_batch), "repetition_penalty": repetition_penalty * np.ones(total_batch), "rpslope": rpslope * np.ones(total_batch), "rprange": np.full(total_batch, rprange, 
dtype=np.uint32), diff --git a/warpers.py b/warpers.py index 7c4f854b..9c1f88eb 100644 --- a/warpers.py +++ b/warpers.py @@ -148,3 +148,32 @@ class TypicalLogitsWarper(LogitsWarper): indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores + + +class TopALogitsWarper(LogitsWarper): + def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + top_a = float(top_a) + if top_a < 0 or top_a > 1.0: + raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}") + self.top_a = top_a + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if self.filter_value >= 1.0: + return scores + + sorted_logits, sorted_indices = torch.sort(scores, descending=True) + probs = sorted_logits.softmax(dim=-1) + + # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept) + probs_max = probs[..., 0, None] + sorted_indices_to_remove = probs >= probs_max * probs_max * self.top_a + + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores From 42b7a327b2440a3a48a969059beb06293a26e239 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 10 Jun 2022 22:34:14 -0400 Subject: [PATCH 2/4] Fix an unfortunate typo in top-a warper --- warpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/warpers.py b/warpers.py index 9c1f88eb..fb683f50 100644 --- a/warpers.py +++ b/warpers.py @@ -168,7 +168,7 @@ class TopALogitsWarper(LogitsWarper): # Remove tokens with probability less than 
top_a*(max(probs))^2 (token with 0 are kept) probs_max = probs[..., 0, None] - sorted_indices_to_remove = probs >= probs_max * probs_max * self.top_a + sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep From 5c81374a48822cdaeb742d9aa007a152db30260d Mon Sep 17 00:00:00 2001 From: Henk Date: Sat, 11 Jun 2022 22:04:37 +0200 Subject: [PATCH 3/4] Top A for GooseAi --- aiserver.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiserver.py b/aiserver.py index f14bbd77..b51fe538 100644 --- a/aiserver.py +++ b/aiserver.py @@ -4668,6 +4668,7 @@ def oairequest(txt, min, max): 'prompt': txt, 'max_tokens': vars.genamt, 'temperature': vars.temp, + 'top_a': vars.top_a, 'top_p': vars.top_p, 'top_k': vars.top_k, 'tfs': vars.tfs, From 66c0dda485d9be7df300e7e791d5b625f0c825e3 Mon Sep 17 00:00:00 2001 From: Henk Date: Sat, 11 Jun 2022 22:54:51 +0200 Subject: [PATCH 4/4] Hide (Broken) Chatbot Models Removing this option because they are currently unavailable. People who still have them can load them through the load from file option. Once they have been retrained and reuploaded I will add the menu back. --- aiserver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index b51fe538..6267aec2 100644 --- a/aiserver.py +++ b/aiserver.py @@ -90,7 +90,6 @@ mainmenu = [ ["Adventure Models", "adventurelist", ""], ["Novel Models", "novellist", ""], ["NSFW Models", "nsfwlist", ""], - ["Chatbot Models", "chatlist", ""], ["Untuned GPT-Neo/J", "gptneolist", ""], ["Untuned Fairseq Dense", "fsdlist", ""], ["Untuned OPT", "optlist", ""],