Merge commit '8a38b258f497281af06fcb0c2559f382b419b938' into overhaul-merge

This commit is contained in:
Gnome Ann
2022-06-14 18:36:37 -04:00
6 changed files with 107 additions and 7 deletions

View File

@ -220,6 +220,7 @@ class vars:
temp = 0.5 # Default generator temperature
top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k
top_a = 0.0 # Default generator top-a
tfs = 1.0 # Default generator tfs (tail-free sampling)
typical = 1.0 # Default generator typical sampling threshold
numseqs = 1 # Number of sequences to ask the generator to create
@ -657,6 +658,8 @@ def loadmodelsettings():
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("top_a" in js):
vars.top_a = js["top_a"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
@ -693,6 +696,7 @@ def savesettings():
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["typical"] = vars.typical
js["top_a"] = vars.top_a
js["rep_pen"] = vars.rep_pen
js["rep_pen_slope"] = vars.rep_pen_slope
js["rep_pen_range"] = vars.rep_pen_range
@ -773,6 +777,8 @@ def processsettings(js):
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("top_a" in js):
vars.top_a = js["top_a"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
@ -1268,7 +1274,7 @@ def patch_transformers():
# Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
old_call = cls.__call__
@ -1288,6 +1294,7 @@ def patch_transformers():
cls.__call__ = new_call
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
@ -1334,6 +1341,7 @@ def patch_transformers():
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
warper_list = LogitsProcessorList()
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
@ -1962,6 +1970,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
"top_k": int(vars.top_k),
"tfs": float(vars.tfs),
"typical": float(vars.typical),
"top_a": float(vars.top_a),
"repetition_penalty": float(vars.rep_pen),
"rpslope": float(vars.rep_pen_slope),
"rprange": int(vars.rep_pen_range),
@ -2384,6 +2393,7 @@ def lua_has_setting(setting):
"settopk",
"settfs",
"settypical",
"settopa",
"setreppen",
"setreppenslope",
"setreppenrange",
@ -2403,6 +2413,7 @@ def lua_has_setting(setting):
"top_k",
"tfs",
"typical",
"topa",
"reppen",
"reppenslope",
"reppenrange",
@ -2437,6 +2448,7 @@ def lua_get_setting(setting):
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
if(setting in ("settfs", "tfs")): return vars.tfs
if(setting in ("settypical", "typical")): return vars.typical
if(setting in ("settopa", "topa")): return vars.top_a
if(setting in ("setreppen", "reppen")): return vars.rep_pen
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
@ -2472,6 +2484,7 @@ def lua_set_setting(setting, v):
if(setting in ("settopk", "topk")): vars.top_k = v
if(setting in ("settfs", "tfs")): vars.tfs = v
if(setting in ("settypical", "typical")): vars.typical = v
if(setting in ("settopa", "topa")): vars.top_a = v
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
@ -2862,6 +2875,11 @@ def get_message(msg):
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'settopa'):
vars.top_a = float(msg['data'])
emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setreppen'):
vars.rep_pen = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
@ -3988,6 +4006,7 @@ def sendtocolab(txt, min, max):
'top_k': vars.top_k,
'tfs': vars.tfs,
'typical': vars.typical,
'topa': vars.top_a,
'numseqs': vars.numseqs,
'retfultxt': False
}
@ -4125,6 +4144,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
top_k=vars.top_k,
tfs=vars.tfs,
typical=vars.typical,
top_a=vars.top_a,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
rpslope=vars.rep_pen_slope,
@ -4311,6 +4331,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True)
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)