Top-A sampling

This commit is contained in:
Gnome Ann
2022-06-10 22:28:20 -04:00
parent 98c2aad072
commit fdb2a7fa4c
7 changed files with 108 additions and 8 deletions

View File

@ -212,6 +212,7 @@ class vars:
temp = 0.5 # Default generator temperature
top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k
top_a = 0.0 # Default generator top-a
tfs = 1.0 # Default generator tfs (tail-free sampling)
typical = 1.0 # Default generator typical sampling threshold
numseqs = 1 # Number of sequences to ask the generator to create
@ -577,6 +578,8 @@ def loadmodelsettings():
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("top_a" in js):
vars.top_a = js["top_a"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
@ -613,6 +616,7 @@ def savesettings():
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["typical"] = vars.typical
js["top_a"] = vars.top_a
js["rep_pen"] = vars.rep_pen
js["rep_pen_slope"] = vars.rep_pen_slope
js["rep_pen_range"] = vars.rep_pen_range
@ -693,6 +697,8 @@ def processsettings(js):
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("top_a" in js):
vars.top_a = js["top_a"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
@ -1379,7 +1385,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
# Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
old_call = cls.__call__
@ -1399,6 +1405,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
cls.__call__ = new_call
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
@ -1445,6 +1452,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
warper_list = LogitsProcessorList()
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
@ -1814,6 +1822,7 @@ else:
"top_k": int(vars.top_k),
"tfs": float(vars.tfs),
"typical": float(vars.typical),
"top_a": float(vars.top_a),
"repetition_penalty": float(vars.rep_pen),
"rpslope": float(vars.rep_pen_slope),
"rprange": int(vars.rep_pen_range),
@ -2176,6 +2185,7 @@ def lua_has_setting(setting):
"settopk",
"settfs",
"settypical",
"settopa",
"setreppen",
"setreppenslope",
"setreppenrange",
@ -2195,6 +2205,7 @@ def lua_has_setting(setting):
"top_k",
"tfs",
"typical",
"topa",
"reppen",
"reppenslope",
"reppenrange",
@ -2229,6 +2240,7 @@ def lua_get_setting(setting):
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
if(setting in ("settfs", "tfs")): return vars.tfs
if(setting in ("settypical", "typical")): return vars.typical
if(setting in ("settopa", "topa")): return vars.top_a
if(setting in ("setreppen", "reppen")): return vars.rep_pen
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
@ -2264,6 +2276,7 @@ def lua_set_setting(setting, v):
if(setting in ("settopk", "topk")): vars.top_k = v
if(setting in ("settfs", "tfs")): vars.tfs = v
if(setting in ("settypical", "typical")): vars.typical = v
if(setting in ("settopa", "topa")): vars.top_a = v
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
@ -2688,6 +2701,11 @@ def get_message(msg):
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'settopa'):
vars.top_a = float(msg['data'])
emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setreppen'):
vars.rep_pen = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
@ -3748,6 +3766,7 @@ def sendtocolab(txt, min, max):
'top_k': vars.top_k,
'tfs': vars.tfs,
'typical': vars.typical,
'topa': vars.top_a,
'numseqs': vars.numseqs,
'retfultxt': False
}
@ -3885,6 +3904,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
top_k=vars.top_k,
tfs=vars.tfs,
typical=vars.typical,
top_a=vars.top_a,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
rpslope=vars.rep_pen_slope,
@ -4071,6 +4091,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True)
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)