Merge pull request #106 from VE-FORBRYDERNE/typical

Typical sampling
henk717 2022-03-28 00:14:09 +02:00 committed by GitHub
commit 77ae893f4d
7 changed files with 167 additions and 12 deletions

View File

@@ -154,6 +154,7 @@ class vars:
top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k
tfs = 1.0 # Default generator tfs (tail-free sampling)
typical = 1.0 # Default generator typical sampling threshold
numseqs = 1 # Number of sequences to ask the generator to create
gamestarted = False # Whether the game has started (disables UI elements)
gamesaved = True # Whether or not current game is saved
@@ -499,6 +500,8 @@ def loadmodelsettings():
vars.top_k = js["top_k"]
if("tfs" in js):
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("rep_pen" in js): if("rep_pen" in js):
vars.rep_pen = js["rep_pen"] vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js): if("rep_pen_slope" in js):
@@ -534,6 +537,7 @@ def savesettings():
js["top_p"] = vars.top_p
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["typical"] = vars.typical
js["rep_pen"] = vars.rep_pen js["rep_pen"] = vars.rep_pen
js["rep_pen_slope"] = vars.rep_pen_slope js["rep_pen_slope"] = vars.rep_pen_slope
js["rep_pen_range"] = vars.rep_pen_range js["rep_pen_range"] = vars.rep_pen_range
@@ -600,6 +604,8 @@ def loadsettings():
vars.top_k = js["top_k"]
if("tfs" in js):
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("rep_pen" in js): if("rep_pen" in js):
vars.rep_pen = js["rep_pen"] vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js): if("rep_pen_slope" in js):
@@ -1172,7 +1178,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
# Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
-from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper
+from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
old_call = cls.__call__
@@ -1194,6 +1200,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
@@ -1239,6 +1246,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
return warper_list
@@ -1540,6 +1548,7 @@ else:
"temp": float(vars.temp),
"top_k": int(vars.top_k),
"tfs": float(vars.tfs),
"typical": float(vars.typical),
"repetition_penalty": float(vars.rep_pen), "repetition_penalty": float(vars.rep_pen),
"rpslope": float(vars.rep_pen_slope), "rpslope": float(vars.rep_pen_slope),
"rprange": int(vars.rep_pen_range), "rprange": int(vars.rep_pen_range),
@@ -1901,6 +1910,7 @@ def lua_has_setting(setting):
"settopp",
"settopk",
"settfs",
"settypical",
"setreppen", "setreppen",
"setreppenslope", "setreppenslope",
"setreppenrange", "setreppenrange",
@@ -1919,6 +1929,7 @@ def lua_has_setting(setting):
"topk",
"top_k",
"tfs",
"typical",
"reppen", "reppen",
"reppenslope", "reppenslope",
"reppenrange", "reppenrange",
@@ -1952,6 +1963,7 @@ def lua_get_setting(setting):
if(setting in ("settopp", "topp", "top_p")): return vars.top_p
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
if(setting in ("settfs", "tfs")): return vars.tfs
if(setting in ("settypical", "typical")): return vars.typical
if(setting in ("setreppen", "reppen")): return vars.rep_pen if(setting in ("setreppen", "reppen")): return vars.rep_pen
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
@@ -1986,6 +1998,7 @@ def lua_set_setting(setting, v):
if(setting in ("settopp", "topp")): vars.top_p = v
if(setting in ("settopk", "topk")): vars.top_k = v
if(setting in ("settfs", "tfs")): vars.tfs = v
if(setting in ("settypical", "typical")): vars.typical = v
if(setting in ("setreppen", "reppen")): vars.rep_pen = v if(setting in ("setreppen", "reppen")): vars.rep_pen = v
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
@@ -2382,6 +2395,11 @@ def get_message(msg):
emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'settypical'):
vars.typical = float(msg['data'])
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setreppen'):
vars.rep_pen = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
@@ -3442,6 +3460,7 @@ def sendtocolab(txt, min, max):
'top_p': vars.top_p,
'top_k': vars.top_k,
'tfs': vars.tfs,
'typical': vars.typical,
'numseqs': vars.numseqs,
'retfultxt': False
}
@@ -3578,6 +3597,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
typical=vars.typical,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
rpslope=vars.rep_pen_slope,
@@ -3763,6 +3783,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p}, broadcast=True)
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
@@ -4341,6 +4362,7 @@ def oairequest(txt, min, max):
'top_p': vars.top_p,
'top_k': vars.top_k,
'tfs': vars.tfs,
'typical': vars.typical,
'repetition_penalty': vars.rep_pen,
'repetition_penalty_slope': vars.rep_pen_slope,
'repetition_penalty_range': vars.rep_pen_range,

View File

@@ -866,6 +866,7 @@ return function(_python, _bridged)
---@field settopp number
---@field settopk integer
---@field settfs number
---@field settypical number
---@field setreppen number
---@field setreppenslope number
---@field setreppenrange number
@@ -882,6 +883,7 @@ return function(_python, _bridged)
---@field top_p number
---@field top_k integer
---@field tfs number
---@field typical number
---@field reppen number
---@field reppenslope number
---@field reppenrange number

View File

@@ -51,8 +51,19 @@ gensettingstf = [
"min": 0.0,
"max": 1.0,
"step": 0.05,
-"default": 0.0,
+"default": 1.0,
"tooltip": "Alternative sampling method; it is recommended to disable top_p and top_k (set top_p to 1 and top_k to 0) if using this. 0.95 is thought to be a good value. (Put this value on 1 to disable its effect)"
},
{
"uitype": "slider",
"unit": "float",
"label": "Typical Sampling",
"id": "settypical",
"min": 0.0,
"max": 1.0,
"step": 0.05,
"default": 1.0,
"tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect."
},
{
"uitype": "slider",

View File

@@ -2041,6 +2041,10 @@ $(document).ready(function(){
// Send current tfs value to input
$("#settfs").val(parseFloat(msg.data));
$("#settfscur").html(msg.data);
} else if(msg.cmd == "updatetypical") {
// Send current typical value to input
$("#settypical").val(parseFloat(msg.data));
$("#settypicalcur").html(msg.data);
} else if(msg.cmd == "updatereppen") { } else if(msg.cmd == "updatereppen") {
// Send current rep pen value to input // Send current rep pen value to input
$("#setreppen").val(parseFloat(msg.data)); $("#setreppen").val(parseFloat(msg.data));
@ -2077,6 +2081,9 @@ $(document).ready(function(){
} else if(msg.cmd == "setlabeltfs") { } else if(msg.cmd == "setlabeltfs") {
// Update setting label with value from server // Update setting label with value from server
$("#settfscur").html(msg.data); $("#settfscur").html(msg.data);
} else if(msg.cmd == "setlabeltypical") {
// Update setting label with value from server
$("#settypicalcur").html(msg.data);
} else if(msg.cmd == "setlabelreppen") { } else if(msg.cmd == "setlabelreppen") {
// Update setting label with value from server // Update setting label with value from server
$("#setreppencur").html(msg.data); $("#setreppencur").html(msg.data);

View File

@@ -17,7 +17,7 @@
<script src="static/bootstrap.min.js"></script>
<script src="static/bootstrap-toggle.min.js"></script>
<script src="static/rangy-core.min.js"></script>
-<script src="static/application.js?ver=1.17a"></script>
+<script src="static/application.js?ver=1.17b"></script>
</head>
<body>
<input type="file" id="remote-save-select" accept="application/json" style="display:none">

View File

@@ -67,6 +67,7 @@ def settings_callback() -> dict:
"temp": 0.5,
"top_k": 0,
"tfs": 1.0,
"typical": 1.0,
"repetition_penalty": 1.0, "repetition_penalty": 1.0,
"rpslope": 0.0, "rpslope": 0.0,
"rprange": 0, "rprange": 0,
@@ -155,11 +156,11 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
logits[tokens] = penalty_logits
return logits
-def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
+def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
'''
-This gets called by generate_loop_fn to apply a series of 4 filters
-to the logits (top-k, then top-p, then TFS, then temperature) before
-picking one token using the modified logits
+This gets called by generate_loop_fn to apply a series of 5 filters
+to the logits (top-k, then top-p, then TFS, then typical, then temperature)
+before picking one token using the modified logits
'''
# Top-k (keep only the k tokens with the highest logits and remove
# the rest, by setting their logits to negative infinity)
@@ -246,6 +247,37 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
return np.where(indices_to_remove, -np.inf, logits)
if tfs < 1.0:
logits = tail_free_filter(logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them
probs = jax.nn.softmax(logits)
with np.errstate(divide="ignore"):
log_probs = np.log(probs)
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
# Determine absolute difference between the negative entropy and the
# log probabilities
entropy_deviation = np.abs(neg_entropy - log_probs)
# Keep certain tokens such that the sum of the entropy_deviation of the
# kept tokens is the smallest possible value such that the sum of the
# softmax probabilities of the kept tokens is at least the threshold
# value (by sorting the tokens in ascending order of entropy_deviation
# and then keeping the smallest possible number of tokens from the
# beginning such that sum of softmax probabilities is at or above the
# threshold)
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= typical
sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
sorted_indices_to_remove[0] = False
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(entropy_deviation),
sorted_indices_to_remove,
)
return np.where(indices_to_remove, -jnp.inf, logits)
if typical < 1.0:
logits = typical_filter(logits)
# Temperature (just divide the logits by the temperature)
logits /= temp
# Finally, pick one token using the softmax thingy again (it gives
@@ -298,11 +330,11 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
# positions in the logits array
return logits.at[tokens].set(penalty_logits)
-def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
+def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
'''
-This gets called by generate_loop_fn to apply a series of 4 filters
-to the logits (top-k, then top-p, then TFS, then temperature) before
-picking one token using the modified logits
+This gets called by generate_loop_fn to apply a series of 5 filters
+to the logits (top-k, then top-p, then TFS, then typical, then temperature)
+before picking one token using the modified logits
'''
# Top-k (keep only the k tokens with the highest logits and remove
# the rest, by setting their logits to negative infinity)
@@ -386,6 +418,35 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them
probs = jax.nn.softmax(logits)
log_probs = jnp.log(probs)
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
# Determine absolute difference between the negative entropy and the
# log probabilities
entropy_deviation = jnp.abs(neg_entropy - log_probs)
# Keep certain tokens such that the sum of the entropy_deviation of the
# kept tokens is the smallest possible value such that the sum of the
# softmax probabilities of the kept tokens is at least the threshold
# value (by sorting the tokens in ascending order of entropy_deviation
# and then keeping the smallest possible number of tokens from the
# beginning such that sum of softmax probabilities is at or above the
# threshold)
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(entropy_deviation),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits)
# Temperature (just divide the logits by the temperature)
def temp_filter(logits):
return logits / temp
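One detail of the static path worth spelling out: kobold_sample_static is traced/compiled, so the threshold check cannot be an ordinary Python if; jax.lax.cond selects between typical_filter and an identity function inside the traced computation. A tiny self-contained illustration of that pattern (toy function and values, not the repository's code):

import jax
import jax.numpy as jnp

@jax.jit
def maybe_halve(threshold, x):
    # Under jit, `threshold < 1.0` is a traced boolean, so a Python `if` would raise;
    # jax.lax.cond chooses the branch inside the compiled graph instead.
    return jax.lax.cond(threshold < 1.0, lambda v: v / 2, lambda v: v, x)

print(maybe_halve(0.5, jnp.arange(4.0)))  # filter branch applied
print(maybe_halve(1.0, jnp.arange(4.0)))  # identity branch, input returned unchanged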
@@ -742,6 +803,7 @@ def infer_static(
temp=0.5,
top_k=0,
tfs=1.0,
typical=1.0,
repetition_penalty=1.0,
rpslope=0.0,
rprange=0,
@@ -764,6 +826,7 @@ def infer_static(
"temp": temp * np.ones(total_batch),
"top_p": top_p * np.ones(total_batch),
"tfs": tfs * np.ones(total_batch),
"typical": typical * np.ones(total_batch),
"repetition_penalty": repetition_penalty * np.ones(total_batch), "repetition_penalty": repetition_penalty * np.ones(total_batch),
"rpslope": rpslope * np.ones(total_batch), "rpslope": rpslope * np.ones(total_batch),
"rprange": np.full(total_batch, rprange, dtype=np.uint32), "rprange": np.full(total_batch, rprange, dtype=np.uint32),

View File

@@ -62,7 +62,7 @@ class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
tfs = float(tfs)
if tfs < 0 or tfs > 1.0:
-raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
+raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")
self.tfs = tfs
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@@ -98,3 +98,53 @@ class TailFreeLogitsWarper(LogitsWarper):
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class TypicalLogitsWarper(LogitsWarper):
'''
Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf
'''
def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
typical = float(typical)
if typical < 0 or typical > 1.0:
raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")
self.typical = typical
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if self.typical >= 1.0:
return scores
# Compute softmax probabilities and the natural logarithms of them
probs = scores.softmax(dim=-1)
log_probs = probs.log()
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = (probs * log_probs).sum(dim=-1, keepdim=True)
# Determine absolute difference between the negative entropy and the
# log probabilities
entropy_deviation = (neg_entropy - log_probs).abs()
# Keep certain tokens such that the sum of the entropy_deviation of the
# kept tokens is the smallest possible value such that the sum of the
# softmax probabilities of the kept tokens is at least the threshold
# value (by sorting the tokens in ascending order of entropy_deviation
# and then keeping the smallest possible number of tokens from the
# beginning such that sum of softmax probabilities is at or above the
# threshold)
_, sorted_indices = torch.sort(entropy_deviation)
sorted_logits = probs.gather(-1, sorted_indices)
sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= self.typical
sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
min_tokens_to_keep = max(self.min_tokens_to_keep, 1)
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., : min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
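For completeness, a hedged sketch of how the new warper could be exercised on its own, outside the dynamic_processor_wrap patching above (the prompt tokens and score values below are invented for illustration):

import torch
from transformers import LogitsProcessorList
from warpers import TypicalLogitsWarper

warpers = LogitsProcessorList([TypicalLogitsWarper(typical=0.2, min_tokens_to_keep=1)])
input_ids = torch.zeros((1, 4), dtype=torch.long)                     # dummy prompt ids
scores = torch.log(torch.tensor([[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]]))  # dummy next-token logits
filtered = warpers(input_ids, scores)                                 # non-typical tokens -> filter_value (-inf)
probs = torch.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)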