Top-A sampling
This commit is contained in:
parent
98c2aad072
commit
fdb2a7fa4c
23
aiserver.py
23
aiserver.py
|
@ -212,6 +212,7 @@ class vars:
|
|||
temp = 0.5 # Default generator temperature
|
||||
top_p = 0.9 # Default generator top_p
|
||||
top_k = 0 # Default generator top_k
|
||||
top_a = 0.0 # Default generator top-a
|
||||
tfs = 1.0 # Default generator tfs (tail-free sampling)
|
||||
typical = 1.0 # Default generator typical sampling threshold
|
||||
numseqs = 1 # Number of sequences to ask the generator to create
|
||||
|
@ -577,6 +578,8 @@ def loadmodelsettings():
|
|||
vars.tfs = js["tfs"]
|
||||
if("typical" in js):
|
||||
vars.typical = js["typical"]
|
||||
if("top_a" in js):
|
||||
vars.top_a = js["top_a"]
|
||||
if("rep_pen" in js):
|
||||
vars.rep_pen = js["rep_pen"]
|
||||
if("rep_pen_slope" in js):
|
||||
|
@ -613,6 +616,7 @@ def savesettings():
|
|||
js["top_k"] = vars.top_k
|
||||
js["tfs"] = vars.tfs
|
||||
js["typical"] = vars.typical
|
||||
js["top_a"] = vars.top_a
|
||||
js["rep_pen"] = vars.rep_pen
|
||||
js["rep_pen_slope"] = vars.rep_pen_slope
|
||||
js["rep_pen_range"] = vars.rep_pen_range
|
||||
|
@ -693,6 +697,8 @@ def processsettings(js):
|
|||
vars.tfs = js["tfs"]
|
||||
if("typical" in js):
|
||||
vars.typical = js["typical"]
|
||||
if("top_a" in js):
|
||||
vars.top_a = js["top_a"]
|
||||
if("rep_pen" in js):
|
||||
vars.rep_pen = js["rep_pen"]
|
||||
if("rep_pen_slope" in js):
|
||||
|
@ -1379,7 +1385,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
|||
|
||||
# Patch transformers to use our custom logit warpers
|
||||
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
|
||||
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
|
||||
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper
|
||||
|
||||
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
|
||||
old_call = cls.__call__
|
||||
|
@ -1399,6 +1405,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
|||
cls.__call__ = new_call
|
||||
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
|
||||
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
|
||||
dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
|
||||
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
||||
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
||||
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
|
||||
|
@ -1445,6 +1452,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
|||
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
|
||||
warper_list = LogitsProcessorList()
|
||||
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
|
||||
warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
|
@ -1814,6 +1822,7 @@ else:
|
|||
"top_k": int(vars.top_k),
|
||||
"tfs": float(vars.tfs),
|
||||
"typical": float(vars.typical),
|
||||
"top_a": float(vars.top_a),
|
||||
"repetition_penalty": float(vars.rep_pen),
|
||||
"rpslope": float(vars.rep_pen_slope),
|
||||
"rprange": int(vars.rep_pen_range),
|
||||
|
@ -2176,6 +2185,7 @@ def lua_has_setting(setting):
|
|||
"settopk",
|
||||
"settfs",
|
||||
"settypical",
|
||||
"settopa",
|
||||
"setreppen",
|
||||
"setreppenslope",
|
||||
"setreppenrange",
|
||||
|
@ -2195,6 +2205,7 @@ def lua_has_setting(setting):
|
|||
"top_k",
|
||||
"tfs",
|
||||
"typical",
|
||||
"topa",
|
||||
"reppen",
|
||||
"reppenslope",
|
||||
"reppenrange",
|
||||
|
@ -2229,6 +2240,7 @@ def lua_get_setting(setting):
|
|||
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
|
||||
if(setting in ("settfs", "tfs")): return vars.tfs
|
||||
if(setting in ("settypical", "typical")): return vars.typical
|
||||
if(setting in ("settopa", "topa")): return vars.top_a
|
||||
if(setting in ("setreppen", "reppen")): return vars.rep_pen
|
||||
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
|
||||
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
|
||||
|
@ -2264,6 +2276,7 @@ def lua_set_setting(setting, v):
|
|||
if(setting in ("settopk", "topk")): vars.top_k = v
|
||||
if(setting in ("settfs", "tfs")): vars.tfs = v
|
||||
if(setting in ("settypical", "typical")): vars.typical = v
|
||||
if(setting in ("settopa", "topa")): vars.top_a = v
|
||||
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
|
||||
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
|
||||
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
|
||||
|
@ -2688,6 +2701,11 @@ def get_message(msg):
|
|||
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
|
||||
settingschanged()
|
||||
refresh_settings()
|
||||
elif(msg['cmd'] == 'settopa'):
|
||||
vars.top_a = float(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True)
|
||||
settingschanged()
|
||||
refresh_settings()
|
||||
elif(msg['cmd'] == 'setreppen'):
|
||||
vars.rep_pen = float(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
|
||||
|
@ -3748,6 +3766,7 @@ def sendtocolab(txt, min, max):
|
|||
'top_k': vars.top_k,
|
||||
'tfs': vars.tfs,
|
||||
'typical': vars.typical,
|
||||
'topa': vars.top_a,
|
||||
'numseqs': vars.numseqs,
|
||||
'retfultxt': False
|
||||
}
|
||||
|
@ -3885,6 +3904,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||
top_k=vars.top_k,
|
||||
tfs=vars.tfs,
|
||||
typical=vars.typical,
|
||||
top_a=vars.top_a,
|
||||
numseqs=vars.numseqs,
|
||||
repetition_penalty=vars.rep_pen,
|
||||
rpslope=vars.rep_pen_slope,
|
||||
|
@ -4071,6 +4091,7 @@ def refresh_settings():
|
|||
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
|
||||
|
|
|
@ -867,6 +867,7 @@ return function(_python, _bridged)
|
|||
---@field settopk integer
|
||||
---@field settfs number
|
||||
---@field settypical number
|
||||
---@field settopa number
|
||||
---@field setreppen number
|
||||
---@field setreppenslope number
|
||||
---@field setreppenrange number
|
||||
|
@ -884,6 +885,7 @@ return function(_python, _bridged)
|
|||
---@field top_k integer
|
||||
---@field tfs number
|
||||
---@field typical number
|
||||
---@field topa number
|
||||
---@field reppen number
|
||||
---@field reppenslope number
|
||||
---@field reppenrange number
|
||||
|
|
|
@ -64,6 +64,17 @@ gensettingstf = [
|
|||
"step": 0.05,
|
||||
"default": 1.0,
|
||||
"tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "float",
|
||||
"label": "Top a Sampling",
|
||||
"id": "settopa",
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"default": 0.0,
|
||||
"tooltip": "Alternative sampling method that reduces the randomness of the AI whenever the probability of one token is much higher than all the others. Higher values have a stronger effect. Set this setting to 0 to disable its effect."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
|
|
|
@ -2096,6 +2096,10 @@ $(document).ready(function(){
|
|||
// Send current typical value to input
|
||||
$("#settypicalcur").val(msg.data);
|
||||
$("#settypical").val(parseFloat(msg.data)).trigger("change");
|
||||
} else if(msg.cmd == "updatetopa") {
|
||||
// Send current top a value to input
|
||||
$("#settopacur").val(msg.data);
|
||||
$("#settopa").val(parseFloat(msg.data)).trigger("change");
|
||||
} else if(msg.cmd == "updatereppen") {
|
||||
// Send current rep pen value to input
|
||||
$("#setreppencur").val(msg.data);
|
||||
|
@ -2135,6 +2139,9 @@ $(document).ready(function(){
|
|||
} else if(msg.cmd == "setlabeltypical") {
|
||||
// Update setting label with value from server
|
||||
$("#settypicalcur").val(msg.data);
|
||||
} else if(msg.cmd == "setlabeltypical") {
|
||||
// Update setting label with value from server
|
||||
$("#settopa").val(msg.data);
|
||||
} else if(msg.cmd == "setlabelreppen") {
|
||||
// Update setting label with value from server
|
||||
$("#setreppencur").val(msg.data);
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
<script src="static/bootstrap.min.js"></script>
|
||||
<script src="static/bootstrap-toggle.min.js"></script>
|
||||
<script src="static/rangy-core.min.js"></script>
|
||||
<script src="static/application.js?ver=1.18c"></script>
|
||||
<script src="static/application.js?ver=1.18d"></script>
|
||||
</head>
|
||||
<body>
|
||||
<input type="file" id="remote-save-select" accept="application/json" style="display:none">
|
||||
|
|
|
@ -70,6 +70,7 @@ def settings_callback() -> dict:
|
|||
"top_k": 0,
|
||||
"tfs": 1.0,
|
||||
"typical": 1.0,
|
||||
"top_a": 0.0,
|
||||
"repetition_penalty": 1.0,
|
||||
"rpslope": 0.0,
|
||||
"rprange": 0,
|
||||
|
@ -158,10 +159,10 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
|
|||
logits[tokens] = penalty_logits
|
||||
return logits
|
||||
|
||||
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
|
||||
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
'''
|
||||
This gets called by generate_loop_fn to apply a series of 5 filters
|
||||
to the logits (top-k, then top-p, then TFS, then typical, then temperature)
|
||||
This gets called by generate_loop_fn to apply a series of 6 filters
|
||||
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
||||
before picking one token using the modified logits
|
||||
'''
|
||||
# Top-k (keep only the k tokens with the highest logits and remove
|
||||
|
@ -182,6 +183,20 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
|
|||
return np.where(indices_to_remove, -np.inf, logits)
|
||||
if top_k > 0:
|
||||
logits = top_k_filter(logits)
|
||||
# Top-a (remove all tokens that have softmax probability less than
|
||||
# a*m^2 where m is the maximum softmax probability)
|
||||
def top_a_filter(logits):
|
||||
# Replace every element in the logits array
|
||||
# with e (Euler's number) to the power of that element, and divide
|
||||
# each element of the new array by the sum of the elements in the
|
||||
# new array
|
||||
probabilities = np.array(jax.nn.softmax(logits), copy=True)
|
||||
# Find the largest probability
|
||||
probs_max = probabilities.max()
|
||||
# Remove tokens
|
||||
return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits)
|
||||
if top_a > 0.0:
|
||||
logits = top_a_filter(logits)
|
||||
# Top-p (after sorting the remaining tokens again in descending order of
|
||||
# logit, remove the ones that have cumulative softmax probability
|
||||
# greater than p)
|
||||
|
@ -332,10 +347,10 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
|
|||
# positions in the logits array
|
||||
return logits.at[tokens].set(penalty_logits)
|
||||
|
||||
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
|
||||
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
'''
|
||||
This gets called by generate_loop_fn to apply a series of 5 filters
|
||||
to the logits (top-k, then top-p, then TFS, then typical, then temperature)
|
||||
This gets called by generate_loop_fn to apply a series of 6 filters
|
||||
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
||||
before picking one token using the modified logits
|
||||
'''
|
||||
# Top-k (keep only the k tokens with the highest logits and remove
|
||||
|
@ -355,6 +370,19 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
|
|||
)
|
||||
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
||||
logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
|
||||
# Top-a (remove all tokens that have softmax probability less than
|
||||
# a*m^2 where m is the maximum softmax probability)
|
||||
def top_a_filter(logits):
|
||||
# Replace every element in the logits array
|
||||
# with e (Euler's number) to the power of that element, and divide
|
||||
# each element of the new array by the sum of the elements in the
|
||||
# new array
|
||||
probabilities = jax.nn.softmax(logits)
|
||||
# Find the largest probability
|
||||
probs_max = probabilities.max()
|
||||
# Remove tokens
|
||||
return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits)
|
||||
logits = jax.lax.cond(top_a > 0.0, top_a_filter, lambda x: x, logits)
|
||||
# Top-p (after sorting the remaining tokens again in descending order of
|
||||
# logit, remove the ones that have cumulative softmax probability
|
||||
# greater than p)
|
||||
|
@ -806,6 +834,7 @@ def infer_static(
|
|||
top_k=0,
|
||||
tfs=1.0,
|
||||
typical=1.0,
|
||||
top_a=0.0,
|
||||
repetition_penalty=1.0,
|
||||
rpslope=0.0,
|
||||
rprange=0,
|
||||
|
@ -829,6 +858,7 @@ def infer_static(
|
|||
"top_p": top_p * np.ones(total_batch),
|
||||
"tfs": tfs * np.ones(total_batch),
|
||||
"typical": typical * np.ones(total_batch),
|
||||
"top_a": top_a * np.ones(total_batch),
|
||||
"repetition_penalty": repetition_penalty * np.ones(total_batch),
|
||||
"rpslope": rpslope * np.ones(total_batch),
|
||||
"rprange": np.full(total_batch, rprange, dtype=np.uint32),
|
||||
|
|
29
warpers.py
29
warpers.py
|
@ -148,3 +148,32 @@ class TypicalLogitsWarper(LogitsWarper):
|
|||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
|
||||
class TopALogitsWarper(LogitsWarper):
|
||||
def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
top_a = float(top_a)
|
||||
if top_a < 0 or top_a > 1.0:
|
||||
raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")
|
||||
self.top_a = top_a
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.filter_value >= 1.0:
|
||||
return scores
|
||||
|
||||
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
||||
probs = sorted_logits.softmax(dim=-1)
|
||||
|
||||
# Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
|
||||
probs_max = probs[..., 0, None]
|
||||
sorted_indices_to_remove = probs >= probs_max * probs_max * self.top_a
|
||||
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep
|
||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
|
Loading…
Reference in New Issue