Merge pull request #106 from VE-FORBRYDERNE/typical

Typical sampling
This commit is contained in:
henk717 2022-03-28 00:14:09 +02:00 committed by GitHub
commit 77ae893f4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 167 additions and 12 deletions

View File

@ -154,6 +154,7 @@ class vars:
top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k
tfs = 1.0 # Default generator tfs (tail-free sampling)
typical = 1.0 # Default generator typical sampling threshold
numseqs = 1 # Number of sequences to ask the generator to create
gamestarted = False # Whether the game has started (disables UI elements)
gamesaved = True # Whether or not current game is saved
@ -499,6 +500,8 @@ def loadmodelsettings():
vars.top_k = js["top_k"]
if("tfs" in js):
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
@ -534,6 +537,7 @@ def savesettings():
js["top_p"] = vars.top_p
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["typical"] = vars.typical
js["rep_pen"] = vars.rep_pen
js["rep_pen_slope"] = vars.rep_pen_slope
js["rep_pen_range"] = vars.rep_pen_range
@ -600,6 +604,8 @@ def loadsettings():
vars.top_k = js["top_k"]
if("tfs" in js):
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
@ -1172,7 +1178,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
# Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
old_call = cls.__call__
@ -1194,6 +1200,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
@ -1239,6 +1246,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
return warper_list
@ -1540,6 +1548,7 @@ else:
"temp": float(vars.temp),
"top_k": int(vars.top_k),
"tfs": float(vars.tfs),
"typical": float(vars.typical),
"repetition_penalty": float(vars.rep_pen),
"rpslope": float(vars.rep_pen_slope),
"rprange": int(vars.rep_pen_range),
@ -1901,6 +1910,7 @@ def lua_has_setting(setting):
"settopp",
"settopk",
"settfs",
"settypical",
"setreppen",
"setreppenslope",
"setreppenrange",
@ -1919,6 +1929,7 @@ def lua_has_setting(setting):
"topk",
"top_k",
"tfs",
"typical",
"reppen",
"reppenslope",
"reppenrange",
@ -1952,6 +1963,7 @@ def lua_get_setting(setting):
if(setting in ("settopp", "topp", "top_p")): return vars.top_p
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
if(setting in ("settfs", "tfs")): return vars.tfs
if(setting in ("settypical", "typical")): return vars.typical
if(setting in ("setreppen", "reppen")): return vars.rep_pen
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
@ -1986,6 +1998,7 @@ def lua_set_setting(setting, v):
if(setting in ("settopp", "topp")): vars.top_p = v
if(setting in ("settopk", "topk")): vars.top_k = v
if(setting in ("settfs", "tfs")): vars.tfs = v
if(setting in ("settypical", "typical")): vars.typical = v
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
@ -2382,6 +2395,11 @@ def get_message(msg):
emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'settypical'):
vars.typical = float(msg['data'])
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setreppen'):
vars.rep_pen = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
@ -3442,6 +3460,7 @@ def sendtocolab(txt, min, max):
'top_p': vars.top_p,
'top_k': vars.top_k,
'tfs': vars.tfs,
'typical': vars.typical,
'numseqs': vars.numseqs,
'retfultxt': False
}
@ -3578,6 +3597,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
typical=vars.typical,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
rpslope=vars.rep_pen_slope,
@ -3763,6 +3783,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p}, broadcast=True)
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
@ -4341,6 +4362,7 @@ def oairequest(txt, min, max):
'top_p': vars.top_p,
'top_k': vars.top_k,
'tfs': vars.tfs,
'typical': vars.typical,
'repetition_penalty': vars.rep_pen,
'repetition_penalty_slope': vars.rep_pen_slope,
'repetition_penalty_range': vars.rep_pen_range,

View File

@ -866,6 +866,7 @@ return function(_python, _bridged)
---@field settopp number
---@field settopk integer
---@field settfs number
---@field settypical number
---@field setreppen number
---@field setreppenslope number
---@field setreppenrange number
@ -882,6 +883,7 @@ return function(_python, _bridged)
---@field top_p number
---@field top_k integer
---@field tfs number
---@field typical number
---@field reppen number
---@field reppenslope number
---@field reppenrange number

View File

@ -51,8 +51,19 @@ gensettingstf = [
"min": 0.0,
"max": 1.0,
"step": 0.05,
"default": 0.0,
"default": 1.0,
"tooltip": "Alternative sampling method; it is recommended to disable top_p and top_k (set top_p to 1 and top_k to 0) if using this. 0.95 is thought to be a good value. (Put this value on 1 to disable its effect)"
},
{
"uitype": "slider",
"unit": "float",
"label": "Typical Sampling",
"id": "settypical",
"min": 0.0,
"max": 1.0,
"step": 0.05,
"default": 1.0,
"tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect."
},
{
"uitype": "slider",

View File

@ -2041,6 +2041,10 @@ $(document).ready(function(){
// Send current tfs value to input
$("#settfs").val(parseFloat(msg.data));
$("#settfscur").html(msg.data);
} else if(msg.cmd == "updatetypical") {
// Send current typical value to input
$("#settypical").val(parseFloat(msg.data));
$("#settypicalcur").html(msg.data);
} else if(msg.cmd == "updatereppen") {
// Send current rep pen value to input
$("#setreppen").val(parseFloat(msg.data));
@ -2077,6 +2081,9 @@ $(document).ready(function(){
} else if(msg.cmd == "setlabeltfs") {
// Update setting label with value from server
$("#settfscur").html(msg.data);
} else if(msg.cmd == "setlabeltypical") {
// Update setting label with value from server
$("#settypicalcur").html(msg.data);
} else if(msg.cmd == "setlabelreppen") {
// Update setting label with value from server
$("#setreppencur").html(msg.data);

View File

@ -17,7 +17,7 @@
<script src="static/bootstrap.min.js"></script>
<script src="static/bootstrap-toggle.min.js"></script>
<script src="static/rangy-core.min.js"></script>
<script src="static/application.js?ver=1.17a"></script>
<script src="static/application.js?ver=1.17b"></script>
</head>
<body>
<input type="file" id="remote-save-select" accept="application/json" style="display:none">

View File

@ -67,6 +67,7 @@ def settings_callback() -> dict:
"temp": 0.5,
"top_k": 0,
"tfs": 1.0,
"typical": 1.0,
"repetition_penalty": 1.0,
"rpslope": 0.0,
"rprange": 0,
@ -155,11 +156,11 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
logits[tokens] = penalty_logits
return logits
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
'''
This gets called by generate_loop_fn to apply a series of 4 filters
to the logits (top-k, then top-p, then TFS, then temperature) before
picking one token using the modified logits
This gets called by generate_loop_fn to apply a series of 5 filters
to the logits (top-k, then top-p, then TFS, then typical, then temperature)
before picking one token using the modified logits
'''
# Top-k (keep only the k tokens with the highest logits and remove
# the rest, by setting their logits to negative infinity)
@ -246,6 +247,37 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
return np.where(indices_to_remove, -np.inf, logits)
if tfs < 1.0:
logits = tail_free_filter(logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them
probs = jax.nn.softmax(logits)
with np.errstate(divide="ignore"):
log_probs = np.log(probs)
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
# Determine absolute difference between the negative entropy and the
# log probabilities
entropy_deviation = np.abs(neg_entropy - log_probs)
# Keep certain tokens such that the sum of the entropy_deviation of the
# kept tokens is the smallest possible value such that the sum of the
# softmax probabilities of the kept tokens is at least the threshold
# value (by sorting the tokens in ascending order of entropy_deviation
# and then keeping the smallest possible number of tokens from the
# beginning such that sum of softmax probabilities is at or above the
# threshold)
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= typical
sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
sorted_indices_to_remove[0] = False
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(entropy_deviation),
sorted_indices_to_remove,
)
return np.where(indices_to_remove, -jnp.inf, logits)
if typical < 1.0:
logits = typical_filter(logits)
# Temperature (just divide the logits by the temperature)
logits /= temp
# Finally, pick one token using the softmax thingy again (it gives
@ -298,11 +330,11 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
# positions in the logits array
return logits.at[tokens].set(penalty_logits)
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
'''
This gets called by generate_loop_fn to apply a series of 4 filters
to the logits (top-k, then top-p, then TFS, then temperature) before
picking one token using the modified logits
This gets called by generate_loop_fn to apply a series of 5 filters
to the logits (top-k, then top-p, then TFS, then typical, then temperature)
before picking one token using the modified logits
'''
# Top-k (keep only the k tokens with the highest logits and remove
# the rest, by setting their logits to negative infinity)
@ -386,6 +418,35 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them
probs = jax.nn.softmax(logits)
log_probs = jnp.log(probs)
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
# Determine absolute difference between the negative entropy and the
# log probabilities
entropy_deviation = jnp.abs(neg_entropy - log_probs)
# Keep certain tokens such that the sum of the entropy_deviation of the
# kept tokens is the smallest possible value such that the sum of the
# softmax probabilities of the kept tokens is at least the threshold
# value (by sorting the tokens in ascending order of entropy_deviation
# and then keeping the smallest possible number of tokens from the
# beginning such that sum of softmax probabilities is at or above the
# threshold)
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(entropy_deviation),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits)
# Temperature (just divide the logits by the temperature)
def temp_filter(logits):
return logits / temp
@ -742,6 +803,7 @@ def infer_static(
temp=0.5,
top_k=0,
tfs=1.0,
typical=1.0,
repetition_penalty=1.0,
rpslope=0.0,
rprange=0,
@ -764,6 +826,7 @@ def infer_static(
"temp": temp * np.ones(total_batch),
"top_p": top_p * np.ones(total_batch),
"tfs": tfs * np.ones(total_batch),
"typical": typical * np.ones(total_batch),
"repetition_penalty": repetition_penalty * np.ones(total_batch),
"rpslope": rpslope * np.ones(total_batch),
"rprange": np.full(total_batch, rprange, dtype=np.uint32),

View File

@ -62,7 +62,7 @@ class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
tfs = float(tfs)
if tfs < 0 or tfs > 1.0:
raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")
self.tfs = tfs
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@ -98,3 +98,53 @@ class TailFreeLogitsWarper(LogitsWarper):
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class TypicalLogitsWarper(LogitsWarper):
'''
Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf
'''
def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
typical = float(typical)
if typical < 0 or typical > 1.0:
raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")
self.typical = typical
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if self.filter_value >= 1.0:
return scores
# Compute softmax probabilities and the natural logarithms of them
probs = scores.softmax(dim=-1)
log_probs = probs.log()
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = (probs * log_probs).sum(dim=-1, keepdim=True)
# Determine absolute difference between the negative entropy and the
# log probabilities
entropy_deviation = (neg_entropy - log_probs).abs()
# Keep certain tokens such that the sum of the entropy_deviation of the
# kept tokens is the smallest possible value such that the sum of the
# softmax probabilities of the kept tokens is at least the threshold
# value (by sorting the tokens in ascending order of entropy_deviation
# and then keeping the smallest possible number of tokens from the
# beginning such that sum of softmax probabilities is at or above the
# threshold)
_, sorted_indices = torch.sort(entropy_deviation)
sorted_logits = probs.gather(-1, sorted_indices)
sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= self.typical
sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
min_tokens_to_keep = max(self.min_tokens_to_keep, 1)
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., : min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores