From 20e48b11d7d7ab7a30d47531ec63140278ae1b06 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sun, 27 Mar 2022 16:25:50 -0400 Subject: [PATCH 1/4] Typical sampling --- aiserver.py | 24 ++++++++++++- bridge.lua | 2 ++ gensettings.py | 13 +++++++- static/application.js | 7 ++++ templates/index.html | 2 +- tpu_mtj_backend.py | 78 ++++++++++++++++++++++++++++++++++++++----- warpers.py | 52 ++++++++++++++++++++++++++++- 7 files changed, 166 insertions(+), 12 deletions(-) diff --git a/aiserver.py b/aiserver.py index b5704e9d..1bd6f36c 100644 --- a/aiserver.py +++ b/aiserver.py @@ -154,6 +154,7 @@ class vars: top_p = 0.9 # Default generator top_p top_k = 0 # Default generator top_k tfs = 1.0 # Default generator tfs (tail-free sampling) + typical = 1.0 # Default generator typical sampling threshold numseqs = 1 # Number of sequences to ask the generator to create gamestarted = False # Whether the game has started (disables UI elements) gamesaved = True # Whether or not current game is saved @@ -499,6 +500,8 @@ def loadmodelsettings(): vars.top_k = js["top_k"] if("tfs" in js): vars.tfs = js["tfs"] + if("typical" in js): + vars.typical = js["typical"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] if("rep_pen_slope" in js): @@ -534,6 +537,7 @@ def savesettings(): js["top_p"] = vars.top_p js["top_k"] = vars.top_k js["tfs"] = vars.tfs + js["typical"] = vars.typical js["rep_pen"] = vars.rep_pen js["rep_pen_slope"] = vars.rep_pen_slope js["rep_pen_range"] = vars.rep_pen_range @@ -600,6 +604,8 @@ def loadsettings(): vars.top_k = js["top_k"] if("tfs" in js): vars.tfs = js["tfs"] + if("typical" in js): + vars.typical = js["typical"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] if("rep_pen_slope" in js): @@ -1172,7 +1178,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go # Patch transformers to use our custom logit warpers from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor - from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper + from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper def dynamic_processor_wrap(cls, field_name, var_name, cond=None): old_call = cls.__call__ @@ -1194,6 +1200,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) + dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0) dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0) RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__ RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__ @@ -1239,6 +1246,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) + warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TemperatureLogitsWarper(temperature=0.5)) return warper_list @@ -1540,6 +1548,7 @@ else: "temp": float(vars.temp), "top_k": int(vars.top_k), "tfs": float(vars.tfs), + "typical": float(vars.typical), "repetition_penalty": float(vars.rep_pen), "rpslope": float(vars.rep_pen_slope), "rprange": int(vars.rep_pen_range), @@ -1901,6 +1910,7 @@ def lua_has_setting(setting): "settopp", "settopk", "settfs", + "settypical", "setreppen", "setreppenslope", "setreppenrange", @@ -1919,6 +1929,7 @@ def lua_has_setting(setting): "topk", "top_k", "tfs", + "typical", "reppen", "reppenslope", "reppenrange", @@ -1952,6 +1963,7 @@ def lua_get_setting(setting): if(setting in ("settopp", "topp", "top_p")): return vars.top_p if(setting in ("settopk", "topk", "top_k")): return vars.top_k if(setting in ("settfs", "tfs")): return vars.tfs + if(setting in ("settypical", "typical")): return vars.typical if(setting in ("setreppen", "reppen")): return vars.rep_pen if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range @@ -1986,6 +1998,7 @@ def lua_set_setting(setting, v): if(setting in ("settopp", "topp")): vars.top_p = v if(setting in ("settopk", "topk")): vars.top_k = v if(setting in ("settfs", "tfs")): vars.tfs = v + if(setting in ("settypical", "typical")): vars.typical = v if(setting in ("setreppen", "reppen")): vars.rep_pen = v if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v @@ -2382,6 +2395,11 @@ def get_message(msg): emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']}, broadcast=True) settingschanged() refresh_settings() + elif(msg['cmd'] == 'settypical'): + vars.typical = float(msg['data']) + emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True) + settingschanged() + refresh_settings() elif(msg['cmd'] == 'setreppen'): vars.rep_pen = float(msg['data']) emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True) @@ -3442,6 +3460,7 @@ def sendtocolab(txt, min, max): 'top_p': vars.top_p, 'top_k': vars.top_k, 'tfs': vars.tfs, + 'typical': vars.typical, 'numseqs': vars.numseqs, 'retfultxt': False } @@ -3578,6 +3597,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): top_p=vars.top_p, top_k=vars.top_k, tfs=vars.tfs, + typical=vars.typical, numseqs=vars.numseqs, repetition_penalty=vars.rep_pen, rpslope=vars.rep_pen_slope, @@ -3763,6 +3783,7 @@ def refresh_settings(): emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p}, broadcast=True) emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True) emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True) + emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True) emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True) emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True) emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True) @@ -4341,6 +4362,7 @@ def oairequest(txt, min, max): 'top_p': vars.top_p, 'top_k': vars.top_k, 'tfs': vars.tfs, + 'typical': vars.typical, 'repetition_penalty': vars.rep_pen, 'repetition_penalty_slope': vars.rep_pen_slope, 'repetition_penalty_range': vars.rep_pen_range, diff --git a/bridge.lua b/bridge.lua index b46977c5..796bc33f 100644 --- a/bridge.lua +++ b/bridge.lua @@ -866,6 +866,7 @@ return function(_python, _bridged) ---@field settopp number ---@field settopk integer ---@field settfs number + ---@field settypical number ---@field setreppen number ---@field setreppenslope number ---@field setreppenrange number @@ -882,6 +883,7 @@ return function(_python, _bridged) ---@field top_p number ---@field top_k integer ---@field tfs number + ---@field typical number ---@field reppen number ---@field reppenslope number ---@field reppenrange number diff --git a/gensettings.py b/gensettings.py index 842ff329..e8d4e566 100644 --- a/gensettings.py +++ b/gensettings.py @@ -51,8 +51,19 @@ gensettingstf = [ "min": 0.0, "max": 1.0, "step": 0.05, - "default": 0.0, + "default": 1.0, "tooltip": "Alternative sampling method; it is recommended to disable top_p and top_k (set top_p to 1 and top_k to 0) if using this. 0.95 is thought to be a good value. (Put this value on 1 to disable its effect)" + }, + { + "uitype": "slider", + "unit": "float", + "label": "Typical Sampling", + "id": "settypical", + "min": 0.0, + "max": 1.0, + "step": 0.05, + "default": 1.0, + "tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect." }, { "uitype": "slider", diff --git a/static/application.js b/static/application.js index 88637075..5db71921 100644 --- a/static/application.js +++ b/static/application.js @@ -2041,6 +2041,10 @@ $(document).ready(function(){ // Send current tfs value to input $("#settfs").val(parseFloat(msg.data)); $("#settfscur").html(msg.data); + } else if(msg.cmd == "updatetypical") { + // Send current typical value to input + $("#settypical").val(parseFloat(msg.data)); + $("#settypicalcur").html(msg.data); } else if(msg.cmd == "updatereppen") { // Send current rep pen value to input $("#setreppen").val(parseFloat(msg.data)); @@ -2077,6 +2081,9 @@ $(document).ready(function(){ } else if(msg.cmd == "setlabeltfs") { // Update setting label with value from server $("#settfscur").html(msg.data); + } else if(msg.cmd == "setlabeltypical") { + // Update setting label with value from server + $("#settypicalcur").html(msg.data); } else if(msg.cmd == "setlabelreppen") { // Update setting label with value from server $("#setreppencur").html(msg.data); diff --git a/templates/index.html b/templates/index.html index 6db5c093..f6f66c19 100644 --- a/templates/index.html +++ b/templates/index.html @@ -17,7 +17,7 @@ - + diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 364d39d5..202e24dc 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -67,6 +67,7 @@ def settings_callback() -> dict: "temp": 0.5, "top_k": 0, "tfs": 1.0, + "typical": 1.0, "repetition_penalty": 1.0, "rpslope": 0.0, "rprange": 0, @@ -155,11 +156,11 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat logits[tokens] = penalty_logits return logits -def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): +def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0): ''' - This gets called by generate_loop_fn to apply a series of 4 filters - to the logits (top-k, then top-p, then TFS, then temperature) before - picking one token using the modified logits + This gets called by generate_loop_fn to apply a series of 5 filters + to the logits (top-k, then top-p, then TFS, then typical, then temperature) + before picking one token using the modified logits ''' # Top-k (keep only the k tokens with the highest logits and remove # the rest, by setting their logits to negative infinity) @@ -246,6 +247,36 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): return np.where(indices_to_remove, -np.inf, logits) if tfs < 1.0: logits = tail_free_filter(logits) + # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf + def typical_filter(logits): + # Compute softmax probabilities and the natural logarithms of them + probs = jax.nn.softmax(logits) + log_probs = np.log(probs) + # Compute the negative of entropy, which is the sum of p*ln(p) for all p + # in the set of softmax probabilities of the logits + neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True) + # Determine absolute difference between the negative entropy and the + # log probabilities + entropy_deviation = np.abs(neg_entropy - log_probs) + # Keep certain tokens such that the sum of the entropy_deviation of the + # kept tokens is the smallest possible value such that the sum of the + # softmax probabilities of the kept tokens is at least the threshold + # value (by sorting the tokens in ascending order of entropy_deviation + # and then keeping the smallest possible number of tokens from the + # beginning such that sum of softmax probabilities is at or above the + # threshold) + _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs) + sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= typical + sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1) + sorted_indices_to_remove[0] = False + # Unsort and remove + _, indices_to_remove = jax.lax.sort_key_val( + jnp.argsort(entropy_deviation), + sorted_indices_to_remove, + ) + return np.where(indices_to_remove, -jnp.inf, logits) + if typical < 1.0: + logits = typical_filter(logits) # Temperature (just divide the logits by the temperature) logits /= temp # Finally, pick one token using the softmax thingy again (it gives @@ -298,11 +329,11 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate # positions in the logits array return logits.at[tokens].set(penalty_logits) -def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): +def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0): ''' - This gets called by generate_loop_fn to apply a series of 4 filters - to the logits (top-k, then top-p, then TFS, then temperature) before - picking one token using the modified logits + This gets called by generate_loop_fn to apply a series of 5 filters + to the logits (top-k, then top-p, then TFS, then typical, then temperature) + before picking one token using the modified logits ''' # Top-k (keep only the k tokens with the highest logits and remove # the rest, by setting their logits to negative infinity) @@ -386,6 +417,35 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): ) return jnp.where(indices_to_remove, -jnp.inf, logits) logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits) + # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf + def typical_filter(logits): + # Compute softmax probabilities and the natural logarithms of them + probs = jax.nn.softmax(logits) + log_probs = jnp.log(probs) + # Compute the negative of entropy, which is the sum of p*ln(p) for all p + # in the set of softmax probabilities of the logits + neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True) + # Determine absolute difference between the negative entropy and the + # log probabilities + entropy_deviation = jnp.abs(neg_entropy - log_probs) + # Keep certain tokens such that the sum of the entropy_deviation of the + # kept tokens is the smallest possible value such that the sum of the + # softmax probabilities of the kept tokens is at least the threshold + # value (by sorting the tokens in ascending order of entropy_deviation + # and then keeping the smallest possible number of tokens from the + # beginning such that sum of softmax probabilities is at or above the + # threshold) + _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs) + sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical + sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1) + sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) + # Unsort and remove + _, indices_to_remove = jax.lax.sort_key_val( + jnp.argsort(entropy_deviation), + sorted_indices_to_remove, + ) + return jnp.where(indices_to_remove, -jnp.inf, logits) + logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits) # Temperature (just divide the logits by the temperature) def temp_filter(logits): return logits / temp @@ -742,6 +802,7 @@ def infer_static( temp=0.5, top_k=0, tfs=1.0, + typical=1.0, repetition_penalty=1.0, rpslope=0.0, rprange=0, @@ -764,6 +825,7 @@ def infer_static( "temp": temp * np.ones(total_batch), "top_p": top_p * np.ones(total_batch), "tfs": tfs * np.ones(total_batch), + "typical": typical * np.ones(total_batch), "repetition_penalty": repetition_penalty * np.ones(total_batch), "rpslope": rpslope * np.ones(total_batch), "rprange": np.full(total_batch, rprange, dtype=np.uint32), diff --git a/warpers.py b/warpers.py index 07670f6d..bedd3445 100644 --- a/warpers.py +++ b/warpers.py @@ -62,7 +62,7 @@ class TailFreeLogitsWarper(LogitsWarper): def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): tfs = float(tfs) if tfs < 0 or tfs > 1.0: - raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}") + raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}") self.tfs = tfs self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep @@ -98,3 +98,53 @@ class TailFreeLogitsWarper(LogitsWarper): indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores + + +class TypicalLogitsWarper(LogitsWarper): + ''' + Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf + ''' + + def __init__(self, typical: float, filter_value: -float("Inf"), min_tokens_to_keep: int = 1): + typical = float(typical) + if typical < 0 or typical > 1.0: + raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}") + self.typical = typical + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if self.filter_value >= 1.0: + return scores + + # Compute softmax probabilities and the natural logarithms of them + probs = scores.softmax(dim=-1) + log_probs = probs.log() + + # Compute the negative of entropy, which is the sum of p*ln(p) for all p + # in the set of softmax probabilities of the logits + neg_entropy = (probs * log_probs).sum(dim=-1, keepdim=True) + + # Determine absolute difference between the negative entropy and the + # log probabilities + entropy_deviation = (neg_entropy - log_probs).abs() + + # Keep certain tokens such that the sum of the entropy_deviation of the + # kept tokens is the smallest possible value such that the sum of the + # softmax probabilities of the kept tokens is at least the threshold + # value (by sorting the tokens in ascending order of entropy_deviation + # and then keeping the smallest possible number of tokens from the + # beginning such that sum of softmax probabilities is at or above the + # threshold) + _, sorted_indices = torch.sort(entropy_deviation) + sorted_logits = probs.gather(-1, sorted_indices) + sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= self.typical + sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dim=-1) + + min_tokens_to_keep = max(self.min_tokens_to_keep, 1) + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., : min_tokens_to_keep] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores From d5989d4c62c45e2f170118355543eb8928ac148e Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sun, 27 Mar 2022 16:57:12 -0400 Subject: [PATCH 2/4] Hide division by zero warning in JAX typical filter This warning happens when `np.log` gets an input containing zeros. In that case, NumPy will throw a warning and output negative infinity. Negative infinity is the correct behaviour here, so we can safely ignore the warning. --- tpu_mtj_backend.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 202e24dc..000f1713 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -247,11 +247,12 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty return np.where(indices_to_remove, -np.inf, logits) if tfs < 1.0: logits = tail_free_filter(logits) - # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf + # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) def typical_filter(logits): # Compute softmax probabilities and the natural logarithms of them probs = jax.nn.softmax(logits) - log_probs = np.log(probs) + with np.errstate(divide="ignore"): + log_probs = np.log(probs) # Compute the negative of entropy, which is the sum of p*ln(p) for all p # in the set of softmax probabilities of the logits neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True) @@ -417,7 +418,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ ) return jnp.where(indices_to_remove, -jnp.inf, logits) logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits) - # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf + # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) def typical_filter(logits): # Compute softmax probabilities and the natural logarithms of them probs = jax.nn.softmax(logits) From bbd0a83fef192ddaad59eaa33cefff4be608b99a Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sun, 27 Mar 2022 16:59:23 -0400 Subject: [PATCH 3/4] Fix `TypicalLogitsWarper` argument typing --- warpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/warpers.py b/warpers.py index bedd3445..0eda0eda 100644 --- a/warpers.py +++ b/warpers.py @@ -105,7 +105,7 @@ class TypicalLogitsWarper(LogitsWarper): Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf ''' - def __init__(self, typical: float, filter_value: -float("Inf"), min_tokens_to_keep: int = 1): + def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): typical = float(typical) if typical < 0 or typical > 1.0: raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}") From e2cd49d552ad0df58e7397cc60dfd4e50a774269 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sun, 27 Mar 2022 17:08:57 -0400 Subject: [PATCH 4/4] Typo fix in `TypicalLogitsWarper` --- warpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/warpers.py b/warpers.py index 0eda0eda..bb12cbb0 100644 --- a/warpers.py +++ b/warpers.py @@ -139,7 +139,7 @@ class TypicalLogitsWarper(LogitsWarper): _, sorted_indices = torch.sort(entropy_deviation) sorted_logits = probs.gather(-1, sorted_indices) sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= self.typical - sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dim=-1) + sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1) min_tokens_to_keep = max(self.min_tokens_to_keep, 1) # Keep at least min_tokens_to_keep