diff --git a/aiserver.py b/aiserver.py index d353dee4..1182b36f 100644 --- a/aiserver.py +++ b/aiserver.py @@ -23,6 +23,7 @@ import packaging import contextlib import traceback import threading +from collections.abc import Iterable from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List import requests @@ -100,6 +101,8 @@ class vars: genamt = 80 # Amount of text for each action to generate ikgen = 200 # Number of characters for InferKit to generate rep_pen = 1.1 # Default generator repetition_penalty + rep_pen_slope = 0.0 # Default generator repetition penalty slope + rep_pen_range = 0 # Default generator repetition penalty range temp = 0.5 # Default generator temperature top_p = 0.9 # Default generator top_p top_k = 0 # Default generator top_k @@ -696,65 +699,32 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # Patch transformers to use our custom logit warpers from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor + from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper def dynamic_processor_wrap(cls, field_name, var_name, cond=None): old_call = cls.__call__ def new_call(self, *args, **kwargs): - setattr(self, field_name, getattr(vars, var_name)) + if(not isinstance(field_name, str) and isinstance(field_name, Iterable)): + conds = [] + for f, v in zip(field_name, var_name): + conds.append(getattr(vars, v)) + setattr(self, f, conds[-1]) + else: + conds = getattr(vars, var_name) + setattr(self, field_name, conds) assert len(args) == 2 - if(cond is None or cond(getattr(vars, var_name))): + if(cond is None or cond(conds)): return old_call(self, *args, **kwargs) return args[1] cls.__call__ = new_call - dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen", cond=lambda x: x != 1.0) + dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0) dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) + dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0) + RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__ + RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__ - class TailFreeLogitsWarper(LogitsWarper): - - def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): - tfs = float(tfs) - if tfs < 0 or tfs > 1.0: - raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}") - self.tfs = tfs - self.filter_value = filter_value - self.min_tokens_to_keep = min_tokens_to_keep - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - self.tfs = vars.tfs - - if self.filter_value >= 1.0: - return scores - sorted_logits, sorted_indices = torch.sort(scores, descending=True) - probs = sorted_logits.softmax(dim=-1) - - # Compute second derivative normalized CDF - d2 = probs.diff().diff().abs() - normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True) - normalized_d2_cdf = normalized_d2.cumsum(dim=-1) - - # Remove tokens with CDF value above the threshold (token with 0 are kept) - sorted_indices_to_remove = normalized_d2_cdf > self.tfs - - # Centre the distribution around the cutoff as in the original implementation of the algorithm - sorted_indices_to_remove = torch.cat( - ( - torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device), - sorted_indices_to_remove, - torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device), - ), - dim=-1, - ) - - if self.min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep - sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 - - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - scores = scores.masked_fill(indices_to_remove, self.filter_value) - return scores - class LuaLogitsProcessor(LogitsProcessor): def __init__(self): @@ -1072,6 +1042,8 @@ else: "top_k": int(vars.top_k), "tfs": float(vars.tfs), "repetition_penalty": float(vars.rep_pen), + "rpslope": float(vars.rep_pen_slope), + "rprange": int(vars.rep_pen_range), } # If we're running Colab or OAI, we still need a tokenizer. @@ -1418,6 +1390,8 @@ def lua_has_setting(setting): "settopk", "settfs", "setreppen", + "setreppenslope", + "setreppenrange", "settknmax", "setwidepth", "setuseprompt", @@ -1433,6 +1407,8 @@ def lua_has_setting(setting): "top_k", "tfs", "reppen", + "reppenslope", + "reppenrange", "tknmax", "widepth", "useprompt", @@ -1464,6 +1440,8 @@ def lua_get_setting(setting): if(setting in ("settopk", "topk", "top_k")): return vars.top_k if(setting in ("settfs", "tfs")): return vars.tfs if(setting in ("setreppen", "reppen")): return vars.rep_pen + if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope + if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range if(setting in ("settknmax", "tknmax")): return vars.max_length if(setting == "anotedepth"): return vars.andepth if(setting in ("setwidepth", "widepth")): return vars.widepth @@ -1495,6 +1473,8 @@ def lua_set_setting(setting, v): if(setting in ("settopk", "topk")): vars.top_k = v if(setting in ("settfs", "tfs")): vars.tfs = v if(setting in ("setreppen", "reppen")): vars.rep_pen = v + if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v + if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True if(setting == "anotedepth"): vars.andepth = v; return True if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True @@ -1881,6 +1861,16 @@ def get_message(msg): emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True) settingschanged() refresh_settings() + elif(msg['cmd'] == 'setreppenslope'): + vars.rep_pen_slope = float(msg['data']) + emit('from_server', {'cmd': 'setlabelreppenslope', 'data': msg['data']}, broadcast=True) + settingschanged() + refresh_settings() + elif(msg['cmd'] == 'setreppenrange'): + vars.rep_pen_range = float(msg['data']) + emit('from_server', {'cmd': 'setlabelreppenrange', 'data': msg['data']}, broadcast=True) + settingschanged() + refresh_settings() elif(msg['cmd'] == 'setoutput'): vars.genamt = int(msg['data']) emit('from_server', {'cmd': 'setlabeloutput', 'data': msg['data']}, broadcast=True) @@ -2151,6 +2141,8 @@ def savesettings(): js["top_k"] = vars.top_k js["tfs"] = vars.tfs js["rep_pen"] = vars.rep_pen + js["rep_pen_slope"] = vars.rep_pen_slope + js["rep_pen_range"] = vars.rep_pen_range js["genamt"] = vars.genamt js["max_length"] = vars.max_length js["ikgen"] = vars.ikgen @@ -2206,6 +2198,10 @@ def loadsettings(): vars.tfs = js["tfs"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] + if("rep_pen_slope" in js): + vars.rep_pen_slope = js["rep_pen_slope"] + if("rep_pen_range" in js): + vars.rep_pen_range = js["rep_pen_range"] if("genamt" in js): vars.genamt = js["genamt"] if("max_length" in js): @@ -2285,6 +2281,10 @@ def loadmodelsettings(): vars.tfs = js["tfs"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] + if("rep_pen_slope" in js): + vars.rep_pen_slope = js["rep_pen_slope"] + if("rep_pen_range" in js): + vars.rep_pen_range = js["rep_pen_range"] if("adventure" in js): vars.adventure = js["adventure"] if("chatmode" in js): @@ -2958,6 +2958,8 @@ def sendtocolab(txt, min, max): 'min': min, 'max': max, 'rep_pen': vars.rep_pen, + 'rep_pen_slope': vars.rep_pen_slope, + 'rep_pen_range': vars.rep_pen_range, 'temperature': vars.temp, 'top_p': vars.top_p, 'top_k': vars.top_k, @@ -3099,6 +3101,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): tfs=vars.tfs, numseqs=vars.numseqs, repetition_penalty=vars.rep_pen, + rpslope=vars.rep_pen_slope, + rprange=vars.rep_pen_range, soft_embeddings=vars.sp, soft_tokens=soft_tokens, ) @@ -3281,6 +3285,8 @@ def refresh_settings(): emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True) emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True) emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True) + emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True) + emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True) emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt}, broadcast=True) emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length}, broadcast=True) emit('from_server', {'cmd': 'updatenumseq', 'data': vars.numseqs}, broadcast=True) diff --git a/bridge.lua b/bridge.lua index da63d63f..b46977c5 100644 --- a/bridge.lua +++ b/bridge.lua @@ -867,6 +867,8 @@ return function(_python, _bridged) ---@field settopk integer ---@field settfs number ---@field setreppen number + ---@field setreppenslope number + ---@field setreppenrange number ---@field settknmax integer ---@field setwidepth integer ---@field setuseprompt boolean @@ -881,6 +883,8 @@ return function(_python, _bridged) ---@field top_k integer ---@field tfs number ---@field reppen number + ---@field reppenslope number + ---@field reppenrange number ---@field tknmax integer ---@field widepth integer ---@field useprompt boolean diff --git a/gensettings.py b/gensettings.py index a55ae82c..2c1c1cc1 100644 --- a/gensettings.py +++ b/gensettings.py @@ -52,6 +52,28 @@ gensettingstf = [{ "step": 0.01, "default": 1.1, "tooltip": "Used to penalize words that were already generated or belong to the context (Going over 1.2 breaks 6B models)." + }, + { + "uitype": "slider", + "unit": "int", + "label": "Rep Penalty Range", + "id": "setreppenrange", + "min": 0, + "max": 4096, + "step": 4, + "default": 0, + "tooltip": "Repetition penalty range. If set higher than 0, only applies repetition penalty to the last few tokens of your story rather than applying it to the entire story. This slider controls the amount of tokens at the end of your story to apply it to." + }, + { + "uitype": "slider", + "unit": "float", + "label": "Rep Penalty Slope", + "id": "setreppenslope", + "min": 0.0, + "max": 10.0, + "step": 0.1, + "default": 0.0, + "tooltip": "Repetition penalty slope. If BOTH this setting and Rep Penalty Range are set higher than 0, will use sigmoid interpolation to apply repetition penalty more strongly on tokens that are closer to the end of your story. This setting controls the tension of the sigmoid curve; higher settings will result in the repetition penalty difference between the start and end of your story being more apparent. Setting this to 1 uses linear interpolation; setting this to 0 disables interpolation." }, { "uitype": "slider", diff --git a/static/application.js b/static/application.js index 14d7de37..7e7352fb 100644 --- a/static/application.js +++ b/static/application.js @@ -1991,6 +1991,14 @@ $(document).ready(function(){ // Send current rep pen value to input $("#setreppen").val(parseFloat(msg.data)); $("#setreppencur").html(msg.data); + } else if(msg.cmd == "updatereppenslope") { + // Send current rep pen value to input + $("#setreppenslope").val(parseFloat(msg.data)); + $("#setreppenslopecur").html(msg.data); + } else if(msg.cmd == "updatereppenrange") { + // Send current rep pen value to input + $("#setreppenrange").val(parseFloat(msg.data)); + $("#setreppenrangecur").html(msg.data); } else if(msg.cmd == "updateoutlen") { // Send current output amt value to input $("#setoutput").val(parseInt(msg.data)); @@ -2018,6 +2026,12 @@ $(document).ready(function(){ } else if(msg.cmd == "setlabelreppen") { // Update setting label with value from server $("#setreppencur").html(msg.data); + } else if(msg.cmd == "setlabelreppenslope") { + // Update setting label with value from server + $("#setreppenslopecur").html(msg.data); + } else if(msg.cmd == "setlabelreppenrange") { + // Update setting label with value from server + $("#setreppenrangecur").html(msg.data); } else if(msg.cmd == "setlabeloutput") { // Update setting label with value from server $("#setoutputcur").html(msg.data); diff --git a/templates/index.html b/templates/index.html index 07438174..c66ed052 100644 --- a/templates/index.html +++ b/templates/index.html @@ -17,7 +17,7 @@ - + diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 3b3f48e7..653f8cf1 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1,3 +1,32 @@ +''' +This file is AGPL-licensed. + +Some of the code in this file is from Clover Edition: +https://github.com/cloveranon/Clover-Edition/blob/master/aidungeon/gpt2generator.py + +The license for Clover Edition is shown below: + +Copyright (c) 2019 Nick Walton + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' + import multiprocessing from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar import progressbar @@ -33,6 +62,8 @@ def settings_callback() -> dict: "top_k": 0, "tfs": 1.0, "repetition_penalty": 1.0, + "rpslope": 0.0, + "rprange": 0, } def started_compiling_callback() -> None: @@ -78,26 +109,40 @@ def __batch_xmap(shard_dim=1): return inner -def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty): +def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): ''' This gets called by generate_loop_fn to apply repetition penalty to the 1D array logits using the provided 1D array of tokens to penalize ''' tokens = np.minimum(tokens, params["n_vocab"]-1) # https://github.com/google/jax/issues/3774 + rpslope = np.int32(rpslope) + rprange = np.int32(rprange) + clipped_rprange = rprange if rprange > 0 else tokens.shape[-1] + penalty_arange = np.roll(np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1) # Make a new array with the same length as the tokens array but with # each element replaced by the value at the corresponding index in the # logits array; e.g. # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1], # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5] penalty_logits = np.take(logits, tokens) + # Repetition penalty slope + if rpslope != 0.0 and rprange > 0: + _penalty = (penalty_arange/(rprange - 1)) * 2 - 1 + _penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1)) + _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1) + repetition_penalty = _penalty # Divide positive values by repetition_penalty and multiply negative # values by repetition_penalty (the academic publication that described # this technique actually just only divided, but that would cause tokens # with negative logits to become more likely, which is obviously wrong) penalty_logits = np.where( - penalty_logits > 0, - penalty_logits/repetition_penalty, - penalty_logits*repetition_penalty, + penalty_arange >= 0, + np.where( + penalty_logits > 0, + penalty_logits/repetition_penalty, + penalty_logits*repetition_penalty, + ), + penalty_logits, ) # Finally, put those penalized logit values back into their original # positions in the logits array @@ -202,25 +247,46 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # probability distribution) return jax.random.categorical(key, logits, -1).astype(np.uint32) -def apply_repetition_penalty_static(logits, tokens, repetition_penalty): +def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): ''' This gets called by generate_loop_fn to apply repetition penalty to the 1D array logits using the provided 1D array of tokens to penalize ''' + rpslope = jnp.int32(rpslope) + rprange = jnp.int32(rprange) + clipped_rprange = jax.lax.cond(rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange) + penalty_arange = jnp.roll(jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1) # Make a new array with the same length as the tokens array but with # each element replaced by the value at the corresponding index in the # logits array; e.g. # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1], # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5] penalty_logits = jnp.take(logits, tokens) + # Repetition penalty slope + def apply_slope(carry): + repetition_penalty, rprange = carry + _penalty = (penalty_arange/(rprange - 1)) * 2 - 1 + _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1)) + _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1) + return _penalty + repetition_penalty = jax.lax.cond( + (rpslope != 0.0) & (rprange > 0), # Not a typo; do not use `and` here, it makes JAX crash + apply_slope, + lambda carry: jnp.full(tokens.shape, carry[0]), + (repetition_penalty, rprange), + ) # Divide positive values by repetition_penalty and multiply negative # values by repetition_penalty (the academic publication that described # this technique actually just only divided, but that would cause tokens # with negative logits to become more likely, which is obviously wrong) penalty_logits = jnp.where( - penalty_logits > 0, - penalty_logits/repetition_penalty, - penalty_logits*repetition_penalty, + penalty_arange >= 0, + jnp.where( + penalty_logits > 0, + penalty_logits/repetition_penalty, + penalty_logits*repetition_penalty, + ), + penalty_logits, ) # Finally, put those penalized logit values back into their original # positions in the logits array @@ -325,7 +391,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): pad_token_id = 50256 -def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options): +def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_index, gen_length, rpslope, rprange, sampler_options): numseqs = numseqs_aux.shape[0] gi = data[0][1] def sample_loop_fn(carry): @@ -339,7 +405,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_op logits = apply_repetition_penalty_dynamic( logits, generated, - repetition_penalty + repetition_penalty, + generated_index, + gen_length, + rpslope, + rprange, ) # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively @@ -401,6 +471,8 @@ class PenalizingCausalTransformer(CausalTransformer): initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs)) # Get repetition penalty from the arguments repetition_penalty = sampler_options.pop('repetition_penalty', None) + rpslope = sampler_options.pop('rpslope', None) + rprange = sampler_options.pop('rprange', None) # This is the main generation loop def generate_loop_fn(carry): # Unpack current generate_loop_fn state @@ -427,7 +499,11 @@ class PenalizingCausalTransformer(CausalTransformer): logits = apply_repetition_penalty_static( logits, generated, - repetition_penalty + repetition_penalty, + generated_index, + gen_length, + rpslope, + rprange, ) # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively @@ -586,7 +662,9 @@ class PenalizingCausalTransformer(CausalTransformer): sample_data[i][2] = logits[i] sampler_options = settings_callback() repetition_penalty = sampler_options.pop("repetition_penalty", 1.0) - sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options) + rpslope = sampler_options.pop("rpslope", 0.0) + rprange = sampler_options.pop("rprange", 0) + sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, params["seq"] + n_generated, gen_length, rpslope, rprange, sampler_options) n_generated += 1 for i in range(numseqs): generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1)) @@ -659,6 +737,8 @@ def infer_static( top_k=0, tfs=1.0, repetition_penalty=1.0, + rpslope=0.0, + rprange=0, numseqs=1, gen_len=80, soft_embeddings: Optional[np.array] = None, @@ -679,6 +759,8 @@ def infer_static( "top_p": top_p * np.ones(total_batch), "tfs": tfs * np.ones(total_batch), "repetition_penalty": repetition_penalty * np.ones(total_batch), + "rpslope": rpslope * np.ones(total_batch), + "rprange": np.full(total_batch, rprange, dtype=np.uint32), "top_k": np.full(total_batch, top_k, dtype=np.uint32) } output = network.generate_static( diff --git a/warpers.py b/warpers.py new file mode 100644 index 00000000..07670f6d --- /dev/null +++ b/warpers.py @@ -0,0 +1,100 @@ +''' +This file is AGPL-licensed. + +Some of the code in this file is from Clover Edition: +https://github.com/cloveranon/Clover-Edition/blob/master/aidungeon/gpt2generator.py + +The license for Clover Edition is shown below: + +Copyright (c) 2019 Nick Walton + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' + +import torch +from transformers import LogitsWarper, LogitsProcessor + + +class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + self.penalty_range = int(self.penalty_range) + clipped_penalty_range = min(input_ids.shape[-1], self.penalty_range) + + if self.penalty != 1.0: + if self.penalty_range > 0: + if clipped_penalty_range < input_ids.shape[1]: + input_ids = input_ids[..., -clipped_penalty_range:] + + if self.penalty_slope != 0: + _penalty = (torch.arange(self.penalty_range, dtype=scores.dtype, device=scores.device)/(self.penalty_range - 1)) * 2. - 1 + _penalty = (self.penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (self.penalty_slope - 1)) + _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (self.penalty - 1) + self.penalty = _penalty[..., -clipped_penalty_range:] + + score = torch.gather(scores, 1, input_ids) + score = torch.where(score <= 0, score * self.penalty, score / self.penalty) + scores.scatter_(1, input_ids, score) + + return scores + + +class TailFreeLogitsWarper(LogitsWarper): + + def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + tfs = float(tfs) + if tfs < 0 or tfs > 1.0: + raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}") + self.tfs = tfs + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if self.filter_value >= 1.0: + return scores + sorted_logits, sorted_indices = torch.sort(scores, descending=True) + probs = sorted_logits.softmax(dim=-1) + + # Compute second derivative normalized CDF + d2 = probs.diff().diff().abs() + normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True) + normalized_d2_cdf = normalized_d2.cumsum(dim=-1) + + # Remove tokens with CDF value above the threshold (token with 0 are kept) + sorted_indices_to_remove = normalized_d2_cdf > self.tfs + + # Centre the distribution around the cutoff as in the original implementation of the algorithm + sorted_indices_to_remove = torch.cat( + ( + torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device), + sorted_indices_to_remove, + torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device), + ), + dim=-1, + ) + + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores