Repetition penalty slope and range

This commit is contained in:
Gnome Ann 2022-01-24 15:30:38 -05:00
parent e69265cb4f
commit 3f18888eec
7 changed files with 288 additions and 60 deletions

View File

@ -23,6 +23,7 @@ import packaging
import contextlib
import traceback
import threading
from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
import requests
@ -100,6 +101,8 @@ class vars:
genamt = 80 # Amount of text for each action to generate
ikgen = 200 # Number of characters for InferKit to generate
rep_pen = 1.1 # Default generator repetition_penalty
rep_pen_slope = 0.0 # Default generator repetition penalty slope
rep_pen_range = 0 # Default generator repetition penalty range
temp = 0.5 # Default generator temperature
top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k
@ -696,65 +699,32 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
old_call = cls.__call__
def new_call(self, *args, **kwargs):
setattr(self, field_name, getattr(vars, var_name))
if(not isinstance(field_name, str) and isinstance(field_name, Iterable)):
conds = []
for f, v in zip(field_name, var_name):
conds.append(getattr(vars, v))
setattr(self, f, conds[-1])
else:
conds = getattr(vars, var_name)
setattr(self, field_name, conds)
assert len(args) == 2
if(cond is None or cond(getattr(vars, var_name))):
if(cond is None or cond(conds)):
return old_call(self, *args, **kwargs)
return args[1]
cls.__call__ = new_call
dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen", cond=lambda x: x != 1.0)
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
tfs = float(tfs)
if tfs < 0 or tfs > 1.0:
raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
self.tfs = tfs
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
self.tfs = vars.tfs
if self.filter_value >= 1.0:
return scores
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
probs = sorted_logits.softmax(dim=-1)
# Compute second derivative normalized CDF
d2 = probs.diff().diff().abs()
normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
# Remove tokens with CDF value above the threshold (token with 0 are kept)
sorted_indices_to_remove = normalized_d2_cdf > self.tfs
# Centre the distribution around the cutoff as in the original implementation of the algorithm
sorted_indices_to_remove = torch.cat(
(
torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
sorted_indices_to_remove,
torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
),
dim=-1,
)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class LuaLogitsProcessor(LogitsProcessor):
def __init__(self):
@ -1072,6 +1042,8 @@ else:
"top_k": int(vars.top_k),
"tfs": float(vars.tfs),
"repetition_penalty": float(vars.rep_pen),
"rpslope": float(vars.rep_pen_slope),
"rprange": int(vars.rep_pen_range),
}
# If we're running Colab or OAI, we still need a tokenizer.
@ -1418,6 +1390,8 @@ def lua_has_setting(setting):
"settopk",
"settfs",
"setreppen",
"setreppenslope",
"setreppenrange",
"settknmax",
"setwidepth",
"setuseprompt",
@ -1433,6 +1407,8 @@ def lua_has_setting(setting):
"top_k",
"tfs",
"reppen",
"reppenslope",
"reppenrange",
"tknmax",
"widepth",
"useprompt",
@ -1464,6 +1440,8 @@ def lua_get_setting(setting):
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
if(setting in ("settfs", "tfs")): return vars.tfs
if(setting in ("setreppen", "reppen")): return vars.rep_pen
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
if(setting in ("settknmax", "tknmax")): return vars.max_length
if(setting == "anotedepth"): return vars.andepth
if(setting in ("setwidepth", "widepth")): return vars.widepth
@ -1495,6 +1473,8 @@ def lua_set_setting(setting, v):
if(setting in ("settopk", "topk")): vars.top_k = v
if(setting in ("settfs", "tfs")): vars.tfs = v
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True
if(setting == "anotedepth"): vars.andepth = v; return True
if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True
@ -1881,6 +1861,16 @@ def get_message(msg):
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setreppenslope'):
vars.rep_pen_slope = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppenslope', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setreppenrange'):
vars.rep_pen_range = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppenrange', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setoutput'):
vars.genamt = int(msg['data'])
emit('from_server', {'cmd': 'setlabeloutput', 'data': msg['data']}, broadcast=True)
@ -2151,6 +2141,8 @@ def savesettings():
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["rep_pen"] = vars.rep_pen
js["rep_pen_slope"] = vars.rep_pen_slope
js["rep_pen_range"] = vars.rep_pen_range
js["genamt"] = vars.genamt
js["max_length"] = vars.max_length
js["ikgen"] = vars.ikgen
@ -2206,6 +2198,10 @@ def loadsettings():
vars.tfs = js["tfs"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
vars.rep_pen_slope = js["rep_pen_slope"]
if("rep_pen_range" in js):
vars.rep_pen_range = js["rep_pen_range"]
if("genamt" in js):
vars.genamt = js["genamt"]
if("max_length" in js):
@ -2285,6 +2281,10 @@ def loadmodelsettings():
vars.tfs = js["tfs"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
vars.rep_pen_slope = js["rep_pen_slope"]
if("rep_pen_range" in js):
vars.rep_pen_range = js["rep_pen_range"]
if("adventure" in js):
vars.adventure = js["adventure"]
if("chatmode" in js):
@ -2958,6 +2958,8 @@ def sendtocolab(txt, min, max):
'min': min,
'max': max,
'rep_pen': vars.rep_pen,
'rep_pen_slope': vars.rep_pen_slope,
'rep_pen_range': vars.rep_pen_range,
'temperature': vars.temp,
'top_p': vars.top_p,
'top_k': vars.top_k,
@ -3099,6 +3101,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
rpslope=vars.rep_pen_slope,
rprange=vars.rep_pen_range,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
)
@ -3281,6 +3285,8 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt}, broadcast=True)
emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length}, broadcast=True)
emit('from_server', {'cmd': 'updatenumseq', 'data': vars.numseqs}, broadcast=True)

View File

@ -867,6 +867,8 @@ return function(_python, _bridged)
---@field settopk integer
---@field settfs number
---@field setreppen number
---@field setreppenslope number
---@field setreppenrange number
---@field settknmax integer
---@field setwidepth integer
---@field setuseprompt boolean
@ -881,6 +883,8 @@ return function(_python, _bridged)
---@field top_k integer
---@field tfs number
---@field reppen number
---@field reppenslope number
---@field reppenrange number
---@field tknmax integer
---@field widepth integer
---@field useprompt boolean

View File

@ -52,6 +52,28 @@ gensettingstf = [{
"step": 0.01,
"default": 1.1,
"tooltip": "Used to penalize words that were already generated or belong to the context (Going over 1.2 breaks 6B models)."
},
{
"uitype": "slider",
"unit": "int",
"label": "Rep Penalty Range",
"id": "setreppenrange",
"min": 0,
"max": 4096,
"step": 4,
"default": 0,
"tooltip": "Repetition penalty range. If set higher than 0, only applies repetition penalty to the last few tokens of your story rather than applying it to the entire story. This slider controls the amount of tokens at the end of your story to apply it to."
},
{
"uitype": "slider",
"unit": "float",
"label": "Rep Penalty Slope",
"id": "setreppenslope",
"min": 0.0,
"max": 10.0,
"step": 0.1,
"default": 0.0,
"tooltip": "Repetition penalty slope. If BOTH this setting and Rep Penalty Range are set higher than 0, will use sigmoid interpolation to apply repetition penalty more strongly on tokens that are closer to the end of your story. This setting controls the tension of the sigmoid curve; higher settings will result in the repetition penalty difference between the start and end of your story being more apparent. Setting this to 1 uses linear interpolation; setting this to 0 disables interpolation."
},
{
"uitype": "slider",

View File

@ -1991,6 +1991,14 @@ $(document).ready(function(){
// Send current rep pen value to input
$("#setreppen").val(parseFloat(msg.data));
$("#setreppencur").html(msg.data);
} else if(msg.cmd == "updatereppenslope") {
// Send current rep pen value to input
$("#setreppenslope").val(parseFloat(msg.data));
$("#setreppenslopecur").html(msg.data);
} else if(msg.cmd == "updatereppenrange") {
// Send current rep pen value to input
$("#setreppenrange").val(parseFloat(msg.data));
$("#setreppenrangecur").html(msg.data);
} else if(msg.cmd == "updateoutlen") {
// Send current output amt value to input
$("#setoutput").val(parseInt(msg.data));
@ -2018,6 +2026,12 @@ $(document).ready(function(){
} else if(msg.cmd == "setlabelreppen") {
// Update setting label with value from server
$("#setreppencur").html(msg.data);
} else if(msg.cmd == "setlabelreppenslope") {
// Update setting label with value from server
$("#setreppenslopecur").html(msg.data);
} else if(msg.cmd == "setlabelreppenrange") {
// Update setting label with value from server
$("#setreppenrangecur").html(msg.data);
} else if(msg.cmd == "setlabeloutput") {
// Update setting label with value from server
$("#setoutputcur").html(msg.data);

View File

@ -17,7 +17,7 @@
<script src="static/bootstrap.min.js"></script>
<script src="static/bootstrap-toggle.min.js"></script>
<script src="static/rangy-core.min.js"></script>
<script src="static/application.js?ver=1.16.4ac"></script>
<script src="static/application.js?ver=1.17"></script>
</head>
<body>
<input type="file" id="remote-save-select" accept="application/json" style="display:none">

View File

@ -1,3 +1,32 @@
'''
This file is AGPL-licensed.
Some of the code in this file is from Clover Edition:
https://github.com/cloveranon/Clover-Edition/blob/master/aidungeon/gpt2generator.py
The license for Clover Edition is shown below:
Copyright (c) 2019 Nick Walton
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import multiprocessing
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
import progressbar
@ -33,6 +62,8 @@ def settings_callback() -> dict:
"top_k": 0,
"tfs": 1.0,
"repetition_penalty": 1.0,
"rpslope": 0.0,
"rprange": 0,
}
def started_compiling_callback() -> None:
@ -78,26 +109,40 @@ def __batch_xmap(shard_dim=1):
return inner
def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty):
def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
'''
This gets called by generate_loop_fn to apply repetition penalty
to the 1D array logits using the provided 1D array of tokens to penalize
'''
tokens = np.minimum(tokens, params["n_vocab"]-1) # https://github.com/google/jax/issues/3774
rpslope = np.int32(rpslope)
rprange = np.int32(rprange)
clipped_rprange = rprange if rprange > 0 else tokens.shape[-1]
penalty_arange = np.roll(np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
# Make a new array with the same length as the tokens array but with
# each element replaced by the value at the corresponding index in the
# logits array; e.g.
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
penalty_logits = np.take(logits, tokens)
# Repetition penalty slope
if rpslope != 0.0 and rprange > 0:
_penalty = (penalty_arange/(rprange - 1)) * 2 - 1
_penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
repetition_penalty = _penalty
# Divide positive values by repetition_penalty and multiply negative
# values by repetition_penalty (the academic publication that described
# this technique actually just only divided, but that would cause tokens
# with negative logits to become more likely, which is obviously wrong)
penalty_logits = np.where(
penalty_logits > 0,
penalty_logits/repetition_penalty,
penalty_logits*repetition_penalty,
penalty_arange >= 0,
np.where(
penalty_logits > 0,
penalty_logits/repetition_penalty,
penalty_logits*repetition_penalty,
),
penalty_logits,
)
# Finally, put those penalized logit values back into their original
# positions in the logits array
@ -202,25 +247,46 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
# probability distribution)
return jax.random.categorical(key, logits, -1).astype(np.uint32)
def apply_repetition_penalty_static(logits, tokens, repetition_penalty):
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
'''
This gets called by generate_loop_fn to apply repetition penalty
to the 1D array logits using the provided 1D array of tokens to penalize
'''
rpslope = jnp.int32(rpslope)
rprange = jnp.int32(rprange)
clipped_rprange = jax.lax.cond(rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange)
penalty_arange = jnp.roll(jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
# Make a new array with the same length as the tokens array but with
# each element replaced by the value at the corresponding index in the
# logits array; e.g.
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
penalty_logits = jnp.take(logits, tokens)
# Repetition penalty slope
def apply_slope(carry):
repetition_penalty, rprange = carry
_penalty = (penalty_arange/(rprange - 1)) * 2 - 1
_penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
return _penalty
repetition_penalty = jax.lax.cond(
(rpslope != 0.0) & (rprange > 0), # Not a typo; do not use `and` here, it makes JAX crash
apply_slope,
lambda carry: jnp.full(tokens.shape, carry[0]),
(repetition_penalty, rprange),
)
# Divide positive values by repetition_penalty and multiply negative
# values by repetition_penalty (the academic publication that described
# this technique actually just only divided, but that would cause tokens
# with negative logits to become more likely, which is obviously wrong)
penalty_logits = jnp.where(
penalty_logits > 0,
penalty_logits/repetition_penalty,
penalty_logits*repetition_penalty,
penalty_arange >= 0,
jnp.where(
penalty_logits > 0,
penalty_logits/repetition_penalty,
penalty_logits*repetition_penalty,
),
penalty_logits,
)
# Finally, put those penalized logit values back into their original
# positions in the logits array
@ -325,7 +391,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
pad_token_id = 50256
def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_index, gen_length, rpslope, rprange, sampler_options):
numseqs = numseqs_aux.shape[0]
gi = data[0][1]
def sample_loop_fn(carry):
@ -339,7 +405,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_op
logits = apply_repetition_penalty_dynamic(
logits,
generated,
repetition_penalty
repetition_penalty,
generated_index,
gen_length,
rpslope,
rprange,
)
# Remove any tokens in the badwords list by setting
# their logits to negative infinity which effectively
@ -401,6 +471,8 @@ class PenalizingCausalTransformer(CausalTransformer):
initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
# Get repetition penalty from the arguments
repetition_penalty = sampler_options.pop('repetition_penalty', None)
rpslope = sampler_options.pop('rpslope', None)
rprange = sampler_options.pop('rprange', None)
# This is the main generation loop
def generate_loop_fn(carry):
# Unpack current generate_loop_fn state
@ -427,7 +499,11 @@ class PenalizingCausalTransformer(CausalTransformer):
logits = apply_repetition_penalty_static(
logits,
generated,
repetition_penalty
repetition_penalty,
generated_index,
gen_length,
rpslope,
rprange,
)
# Remove any tokens in the badwords list by setting
# their logits to negative infinity which effectively
@ -586,7 +662,9 @@ class PenalizingCausalTransformer(CausalTransformer):
sample_data[i][2] = logits[i]
sampler_options = settings_callback()
repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
rpslope = sampler_options.pop("rpslope", 0.0)
rprange = sampler_options.pop("rprange", 0)
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, params["seq"] + n_generated, gen_length, rpslope, rprange, sampler_options)
n_generated += 1
for i in range(numseqs):
generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
@ -659,6 +737,8 @@ def infer_static(
top_k=0,
tfs=1.0,
repetition_penalty=1.0,
rpslope=0.0,
rprange=0,
numseqs=1,
gen_len=80,
soft_embeddings: Optional[np.array] = None,
@ -679,6 +759,8 @@ def infer_static(
"top_p": top_p * np.ones(total_batch),
"tfs": tfs * np.ones(total_batch),
"repetition_penalty": repetition_penalty * np.ones(total_batch),
"rpslope": rpslope * np.ones(total_batch),
"rprange": np.full(total_batch, rprange, dtype=np.uint32),
"top_k": np.full(total_batch, top_k, dtype=np.uint32)
}
output = network.generate_static(

100
warpers.py Normal file
View File

@ -0,0 +1,100 @@
'''
This file is AGPL-licensed.
Some of the code in this file is from Clover Edition:
https://github.com/cloveranon/Clover-Edition/blob/master/aidungeon/gpt2generator.py
The license for Clover Edition is shown below:
Copyright (c) 2019 Nick Walton
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import torch
from transformers import LogitsWarper, LogitsProcessor
class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor):
def __init__(self, *args, **kwargs):
pass
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
self.penalty_range = int(self.penalty_range)
clipped_penalty_range = min(input_ids.shape[-1], self.penalty_range)
if self.penalty != 1.0:
if self.penalty_range > 0:
if clipped_penalty_range < input_ids.shape[1]:
input_ids = input_ids[..., -clipped_penalty_range:]
if self.penalty_slope != 0:
_penalty = (torch.arange(self.penalty_range, dtype=scores.dtype, device=scores.device)/(self.penalty_range - 1)) * 2. - 1
_penalty = (self.penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (self.penalty_slope - 1))
_penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (self.penalty - 1)
self.penalty = _penalty[..., -clipped_penalty_range:]
score = torch.gather(scores, 1, input_ids)
score = torch.where(score <= 0, score * self.penalty, score / self.penalty)
scores.scatter_(1, input_ids, score)
return scores
class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
tfs = float(tfs)
if tfs < 0 or tfs > 1.0:
raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
self.tfs = tfs
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if self.filter_value >= 1.0:
return scores
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
probs = sorted_logits.softmax(dim=-1)
# Compute second derivative normalized CDF
d2 = probs.diff().diff().abs()
normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
# Remove tokens with CDF value above the threshold (token with 0 are kept)
sorted_indices_to_remove = normalized_d2_cdf > self.tfs
# Centre the distribution around the cutoff as in the original implementation of the algorithm
sorted_indices_to_remove = torch.cat(
(
torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
sorted_indices_to_remove,
torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
),
dim=-1,
)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores