Typical sampling

This commit is contained in:
Gnome Ann
2022-03-27 16:25:50 -04:00
parent e4c72ca2e5
commit 20e48b11d7
7 changed files with 166 additions and 12 deletions

View File

@ -154,6 +154,7 @@ class vars:
top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k
tfs = 1.0 # Default generator tfs (tail-free sampling)
typical = 1.0 # Default generator typical sampling threshold
numseqs = 1 # Number of sequences to ask the generator to create
gamestarted = False # Whether the game has started (disables UI elements)
gamesaved = True # Whether or not current game is saved
@ -499,6 +500,8 @@ def loadmodelsettings():
vars.top_k = js["top_k"]
if("tfs" in js):
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
@ -534,6 +537,7 @@ def savesettings():
js["top_p"] = vars.top_p
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["typical"] = vars.typical
js["rep_pen"] = vars.rep_pen
js["rep_pen_slope"] = vars.rep_pen_slope
js["rep_pen_range"] = vars.rep_pen_range
@ -600,6 +604,8 @@ def loadsettings():
vars.top_k = js["top_k"]
if("tfs" in js):
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
@ -1172,7 +1178,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
# Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
old_call = cls.__call__
@ -1194,6 +1200,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
@ -1239,6 +1246,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
return warper_list
@ -1540,6 +1548,7 @@ else:
"temp": float(vars.temp),
"top_k": int(vars.top_k),
"tfs": float(vars.tfs),
"typical": float(vars.typical),
"repetition_penalty": float(vars.rep_pen),
"rpslope": float(vars.rep_pen_slope),
"rprange": int(vars.rep_pen_range),
@ -1901,6 +1910,7 @@ def lua_has_setting(setting):
"settopp",
"settopk",
"settfs",
"settypical",
"setreppen",
"setreppenslope",
"setreppenrange",
@ -1919,6 +1929,7 @@ def lua_has_setting(setting):
"topk",
"top_k",
"tfs",
"typical",
"reppen",
"reppenslope",
"reppenrange",
@ -1952,6 +1963,7 @@ def lua_get_setting(setting):
if(setting in ("settopp", "topp", "top_p")): return vars.top_p
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
if(setting in ("settfs", "tfs")): return vars.tfs
if(setting in ("settypical", "typical")): return vars.typical
if(setting in ("setreppen", "reppen")): return vars.rep_pen
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
@ -1986,6 +1998,7 @@ def lua_set_setting(setting, v):
if(setting in ("settopp", "topp")): vars.top_p = v
if(setting in ("settopk", "topk")): vars.top_k = v
if(setting in ("settfs", "tfs")): vars.tfs = v
if(setting in ("settypical", "typical")): vars.typical = v
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
@ -2382,6 +2395,11 @@ def get_message(msg):
emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'settypical'):
vars.typical = float(msg['data'])
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setreppen'):
vars.rep_pen = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
@ -3442,6 +3460,7 @@ def sendtocolab(txt, min, max):
'top_p': vars.top_p,
'top_k': vars.top_k,
'tfs': vars.tfs,
'typical': vars.typical,
'numseqs': vars.numseqs,
'retfultxt': False
}
@ -3578,6 +3597,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
typical=vars.typical,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
rpslope=vars.rep_pen_slope,
@ -3763,6 +3783,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p}, broadcast=True)
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
@ -4341,6 +4362,7 @@ def oairequest(txt, min, max):
'top_p': vars.top_p,
'top_k': vars.top_k,
'tfs': vars.tfs,
'typical': vars.typical,
'repetition_penalty': vars.rep_pen,
'repetition_penalty_slope': vars.rep_pen_slope,
'repetition_penalty_range': vars.rep_pen_range,