From fdb2a7fa4ce87a586aae1590ded4630a1a03e48d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 10 Jun 2022 22:28:20 -0400 Subject: [PATCH] Top-A sampling --- aiserver.py | 23 ++++++++++++++++++++++- bridge.lua | 2 ++ gensettings.py | 11 +++++++++++ static/application.js | 7 +++++++ templates/index.html | 2 +- tpu_mtj_backend.py | 42 ++++++++++++++++++++++++++++++++++++------ warpers.py | 29 +++++++++++++++++++++++++++++ 7 files changed, 108 insertions(+), 8 deletions(-) diff --git a/aiserver.py b/aiserver.py index fefab9b8..f14bbd77 100644 --- a/aiserver.py +++ b/aiserver.py @@ -212,6 +212,7 @@ class vars: temp = 0.5 # Default generator temperature top_p = 0.9 # Default generator top_p top_k = 0 # Default generator top_k + top_a = 0.0 # Default generator top-a tfs = 1.0 # Default generator tfs (tail-free sampling) typical = 1.0 # Default generator typical sampling threshold numseqs = 1 # Number of sequences to ask the generator to create @@ -577,6 +578,8 @@ def loadmodelsettings(): vars.tfs = js["tfs"] if("typical" in js): vars.typical = js["typical"] + if("top_a" in js): + vars.top_a = js["top_a"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] if("rep_pen_slope" in js): @@ -613,6 +616,7 @@ def savesettings(): js["top_k"] = vars.top_k js["tfs"] = vars.tfs js["typical"] = vars.typical + js["top_a"] = vars.top_a js["rep_pen"] = vars.rep_pen js["rep_pen_slope"] = vars.rep_pen_slope js["rep_pen_range"] = vars.rep_pen_range @@ -693,6 +697,8 @@ def processsettings(js): vars.tfs = js["tfs"] if("typical" in js): vars.typical = js["typical"] + if("top_a" in js): + vars.top_a = js["top_a"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] if("rep_pen_slope" in js): @@ -1379,7 +1385,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go # Patch transformers to use our custom logit warpers from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor - from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper + from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper def dynamic_processor_wrap(cls, field_name, var_name, cond=None): old_call = cls.__call__ @@ -1399,6 +1405,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go cls.__call__ = new_call dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0) dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) + dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0) dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0) @@ -1445,6 +1452,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: warper_list = LogitsProcessorList() warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) + warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) @@ -1814,6 +1822,7 @@ else: "top_k": int(vars.top_k), "tfs": float(vars.tfs), "typical": float(vars.typical), + "top_a": float(vars.top_a), "repetition_penalty": float(vars.rep_pen), "rpslope": float(vars.rep_pen_slope), "rprange": int(vars.rep_pen_range), @@ -2176,6 +2185,7 @@ def lua_has_setting(setting): "settopk", "settfs", "settypical", + "settopa", "setreppen", "setreppenslope", "setreppenrange", @@ -2195,6 +2205,7 @@ def lua_has_setting(setting): "top_k", "tfs", "typical", + "topa", "reppen", "reppenslope", "reppenrange", @@ -2229,6 +2240,7 @@ def lua_get_setting(setting): if(setting in ("settopk", "topk", "top_k")): return vars.top_k if(setting in ("settfs", "tfs")): return vars.tfs if(setting in ("settypical", "typical")): return vars.typical + if(setting in ("settopa", "topa")): return vars.top_a if(setting in ("setreppen", "reppen")): return vars.rep_pen if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range @@ -2264,6 +2276,7 @@ def lua_set_setting(setting, v): if(setting in ("settopk", "topk")): vars.top_k = v if(setting in ("settfs", "tfs")): vars.tfs = v if(setting in ("settypical", "typical")): vars.typical = v + if(setting in ("settopa", "topa")): vars.top_a = v if(setting in ("setreppen", "reppen")): vars.rep_pen = v if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v @@ -2688,6 +2701,11 @@ def get_message(msg): emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True) settingschanged() refresh_settings() + elif(msg['cmd'] == 'settopa'): + vars.top_a = float(msg['data']) + emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True) + settingschanged() + refresh_settings() elif(msg['cmd'] == 'setreppen'): vars.rep_pen = float(msg['data']) emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True) @@ -3748,6 +3766,7 @@ def sendtocolab(txt, min, max): 'top_k': vars.top_k, 'tfs': vars.tfs, 'typical': vars.typical, + 'topa': vars.top_a, 'numseqs': vars.numseqs, 'retfultxt': False } @@ -3885,6 +3904,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): top_k=vars.top_k, tfs=vars.tfs, typical=vars.typical, + top_a=vars.top_a, numseqs=vars.numseqs, repetition_penalty=vars.rep_pen, rpslope=vars.rep_pen_slope, @@ -4071,6 +4091,7 @@ def refresh_settings(): emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True) emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True) emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True) + emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True) emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True) emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True) emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True) diff --git a/bridge.lua b/bridge.lua index ed0941c6..fc6c8823 100644 --- a/bridge.lua +++ b/bridge.lua @@ -867,6 +867,7 @@ return function(_python, _bridged) ---@field settopk integer ---@field settfs number ---@field settypical number + ---@field settopa number ---@field setreppen number ---@field setreppenslope number ---@field setreppenrange number @@ -884,6 +885,7 @@ return function(_python, _bridged) ---@field top_k integer ---@field tfs number ---@field typical number + ---@field topa number ---@field reppen number ---@field reppenslope number ---@field reppenrange number diff --git a/gensettings.py b/gensettings.py index e8d4e566..b3007c91 100644 --- a/gensettings.py +++ b/gensettings.py @@ -64,6 +64,17 @@ gensettingstf = [ "step": 0.05, "default": 1.0, "tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect." + }, + { + "uitype": "slider", + "unit": "float", + "label": "Top a Sampling", + "id": "settopa", + "min": 0.0, + "max": 1.0, + "step": 0.01, + "default": 0.0, + "tooltip": "Alternative sampling method that reduces the randomness of the AI whenever the probability of one token is much higher than all the others. Higher values have a stronger effect. Set this setting to 0 to disable its effect." }, { "uitype": "slider", diff --git a/static/application.js b/static/application.js index 3b50281c..55487f76 100644 --- a/static/application.js +++ b/static/application.js @@ -2096,6 +2096,10 @@ $(document).ready(function(){ // Send current typical value to input $("#settypicalcur").val(msg.data); $("#settypical").val(parseFloat(msg.data)).trigger("change"); + } else if(msg.cmd == "updatetopa") { + // Send current top a value to input + $("#settopacur").val(msg.data); + $("#settopa").val(parseFloat(msg.data)).trigger("change"); } else if(msg.cmd == "updatereppen") { // Send current rep pen value to input $("#setreppencur").val(msg.data); @@ -2135,6 +2139,9 @@ $(document).ready(function(){ } else if(msg.cmd == "setlabeltypical") { // Update setting label with value from server $("#settypicalcur").val(msg.data); + } else if(msg.cmd == "setlabeltypical") { + // Update setting label with value from server + $("#settopa").val(msg.data); } else if(msg.cmd == "setlabelreppen") { // Update setting label with value from server $("#setreppencur").val(msg.data); diff --git a/templates/index.html b/templates/index.html index a3214ffa..690535f7 100644 --- a/templates/index.html +++ b/templates/index.html @@ -17,7 +17,7 @@ - + diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index fb2dc7ae..f66ad53c 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -70,6 +70,7 @@ def settings_callback() -> dict: "top_k": 0, "tfs": 1.0, "typical": 1.0, + "top_a": 0.0, "repetition_penalty": 1.0, "rpslope": 0.0, "rprange": 0, @@ -158,10 +159,10 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat logits[tokens] = penalty_logits return logits -def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0): +def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' - This gets called by generate_loop_fn to apply a series of 5 filters - to the logits (top-k, then top-p, then TFS, then typical, then temperature) + This gets called by generate_loop_fn to apply a series of 6 filters + to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) before picking one token using the modified logits ''' # Top-k (keep only the k tokens with the highest logits and remove @@ -182,6 +183,20 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty return np.where(indices_to_remove, -np.inf, logits) if top_k > 0: logits = top_k_filter(logits) + # Top-a (remove all tokens that have softmax probability less than + # a*m^2 where m is the maximum softmax probability) + def top_a_filter(logits): + # Replace every element in the logits array + # with e (Euler's number) to the power of that element, and divide + # each element of the new array by the sum of the elements in the + # new array + probabilities = np.array(jax.nn.softmax(logits), copy=True) + # Find the largest probability + probs_max = probabilities.max() + # Remove tokens + return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits) + if top_a > 0.0: + logits = top_a_filter(logits) # Top-p (after sorting the remaining tokens again in descending order of # logit, remove the ones that have cumulative softmax probability # greater than p) @@ -332,10 +347,10 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate # positions in the logits array return logits.at[tokens].set(penalty_logits) -def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0): +def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' - This gets called by generate_loop_fn to apply a series of 5 filters - to the logits (top-k, then top-p, then TFS, then typical, then temperature) + This gets called by generate_loop_fn to apply a series of 6 filters + to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) before picking one token using the modified logits ''' # Top-k (keep only the k tokens with the highest logits and remove @@ -355,6 +370,19 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ ) return jnp.where(indices_to_remove, -jnp.inf, logits) logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits) + # Top-a (remove all tokens that have softmax probability less than + # a*m^2 where m is the maximum softmax probability) + def top_a_filter(logits): + # Replace every element in the logits array + # with e (Euler's number) to the power of that element, and divide + # each element of the new array by the sum of the elements in the + # new array + probabilities = jax.nn.softmax(logits) + # Find the largest probability + probs_max = probabilities.max() + # Remove tokens + return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits) + logits = jax.lax.cond(top_a > 0.0, top_a_filter, lambda x: x, logits) # Top-p (after sorting the remaining tokens again in descending order of # logit, remove the ones that have cumulative softmax probability # greater than p) @@ -806,6 +834,7 @@ def infer_static( top_k=0, tfs=1.0, typical=1.0, + top_a=0.0, repetition_penalty=1.0, rpslope=0.0, rprange=0, @@ -829,6 +858,7 @@ def infer_static( "top_p": top_p * np.ones(total_batch), "tfs": tfs * np.ones(total_batch), "typical": typical * np.ones(total_batch), + "top_a": top_a * np.ones(total_batch), "repetition_penalty": repetition_penalty * np.ones(total_batch), "rpslope": rpslope * np.ones(total_batch), "rprange": np.full(total_batch, rprange, dtype=np.uint32), diff --git a/warpers.py b/warpers.py index 7c4f854b..9c1f88eb 100644 --- a/warpers.py +++ b/warpers.py @@ -148,3 +148,32 @@ class TypicalLogitsWarper(LogitsWarper): indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores + + +class TopALogitsWarper(LogitsWarper): + def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + top_a = float(top_a) + if top_a < 0 or top_a > 1.0: + raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}") + self.top_a = top_a + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if self.filter_value >= 1.0: + return scores + + sorted_logits, sorted_indices = torch.sort(scores, descending=True) + probs = sorted_logits.softmax(dim=-1) + + # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept) + probs_max = probs[..., 0, None] + sorted_indices_to_remove = probs >= probs_max * probs_max * self.top_a + + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores