top_k and tfs support by Frogging101

Adds top_k and tfs support, also fixes a SocketIO error.
This commit is contained in:
henk717 2021-08-19 14:47:57 +02:00
parent caee12eae0
commit 72bfc417da
3 changed files with 88 additions and 5 deletions

View File

@ -64,6 +64,8 @@ class vars:
rep_pen = 1.1 # Default generator repetition_penalty
temp = 0.5 # Default generator temperature
top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k
tfs = 0.0 # Default generator tfs (tail-free sampling)
numseqs = 1 # Number of sequences to ask the generator to create
gamestarted = False # Whether the game has started (disables UI elements)
prompt = "" # Prompt
@ -449,6 +451,14 @@ def get_message(msg):
vars.top_p = float(msg['data'])
emit('from_server', {'cmd': 'setlabeltopp', 'data': msg['data']})
settingschanged()
elif(msg['cmd'] == 'settopk'):
vars.top_k = int(msg['data'])
emit('from_server', {'cmd': 'setlabeltopk', 'data': msg['data']})
settingschanged()
elif(msg['cmd'] == 'settfs'):
vars.tfs = float(msg['data'])
emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']})
settingschanged()
elif(msg['cmd'] == 'setreppen'):
vars.rep_pen = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']})
@ -588,6 +598,8 @@ def savesettings():
js["andepth"] = vars.andepth
js["temp"] = vars.temp
js["top_p"] = vars.top_p
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["rep_pen"] = vars.rep_pen
js["genamt"] = vars.genamt
js["max_length"] = vars.max_length
@ -623,6 +635,10 @@ def loadsettings():
vars.temp = js["temp"]
if("top_p" in js):
vars.top_p = js["top_p"]
if("top_k" in js):
vars.top_k = js["top_k"]
if("tfs" in js):
vars.tfs = js["tfs"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("genamt" in js):
@ -922,13 +938,19 @@ def generate(txt, min, max):
# Submit input text to generator
try:
top_p = vars.top_p if vars.top_p > 0.0 else None
top_k = vars.top_k if vars.top_k > 0 else None
tfs = vars.tfs if vars.tfs > 0.0 else None
genout = generator(
txt,
do_sample=True,
min_length=min,
max_length=max,
repetition_penalty=vars.rep_pen,
top_p=vars.top_p,
top_p=top_p,
top_k=top_k,
tfs=tfs,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
@ -1016,6 +1038,8 @@ def sendtocolab(txt, min, max):
'rep_pen': vars.rep_pen,
'temperature': vars.temp,
'top_p': vars.top_p,
'top_k': vars.top_k,
'tfs': vars.tfs,
'numseqs': vars.numseqs,
'retfultxt': False
}
@ -1139,6 +1163,8 @@ def refresh_settings():
if(vars.model != "InferKit"):
emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp})
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p})
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k})
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs})
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen})
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt})
emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length})

View File

@ -14,11 +14,33 @@ gensettingstf = [{
"unit": "float",
"label": "Top p Sampling",
"id": "settopp",
"min": 0.1,
"min": 0.0,
"max": 1.0,
"step": 0.05,
"default": 0.9,
"tooltip": "Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious."
},
{
"uitype": "slider",
"unit": "int",
"label": "Top k Sampling",
"id": "settopk",
"min": 0,
"max": 100,
"step": 1,
"default": 0,
"tooltip": "Alternative sampling method, can be combined with top_p."
},
{
"uitype": "slider",
"unit": "float",
"label": "Tail-free Sampling",
"id": "settfs",
"min": 0.0,
"max": 1.0,
"step": 0.05,
"default": 0.0,
"tooltip": "Alternative sampling method; it is recommended to disable (set to 0) top_p and top_k if using this. 0.95 is thought to be a good value."
},
{
"uitype": "slider",
@ -114,12 +136,34 @@ gensettingsik =[{
"unit": "float",
"label": "Top p Sampling",
"id": "settopp",
"min": 0.1,
"min": 0.0,
"max": 1.0,
"step": 0.05,
"default": 1.1,
"tooltip": "Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious."
},
{
"uitype": "slider",
"unit": "int",
"label": "Top k Sampling",
"id": "settopk",
"min": 0,
"max": 100,
"step": 1,
"default": 0,
"tooltip": "Alternative sampling method, can be combined with top_p."
},
{
"uitype": "slider",
"unit": "float",
"label": "Tail-free Sampling",
"id": "settfs",
"min": 0.0,
"max": 1.0,
"step": 0.05,
"default": 0.0,
"tooltip": "Alternative sampling method; it is recommended to disable (set to 0) top_p and top_k if using this. 0.95 is thought to be a good value."
},
{
"uitype": "slider",
"unit": "int",

View File

@ -706,8 +706,7 @@ $(document).ready(function(){
seqselcontents = $("#seqselcontents");
// Connect to SocketIO server
loc = window.document.location;
socket = io.connect(loc.href);
socket = io.connect(window.document.origin);
socket.on('from_server', function(msg) {
if(msg.cmd == "connected") {
@ -779,6 +778,14 @@ $(document).ready(function(){
// Send current top p value to input
$("#settopp").val(parseFloat(msg.data));
$("#settoppcur").html(msg.data);
} else if(msg.cmd == "updatetopk") {
// Send current top k value to input
$("#settopk").val(parseFloat(msg.data));
$("#settopkcur").html(msg.data);
} else if(msg.cmd == "updatetfs") {
// Send current tfs value to input
$("#settfs").val(parseFloat(msg.data));
$("#settfscur").html(msg.data);
} else if(msg.cmd == "updatereppen") {
// Send current rep pen value to input
$("#setreppen").val(parseFloat(msg.data));
@ -801,6 +808,12 @@ $(document).ready(function(){
} else if(msg.cmd == "setlabeltopp") {
// Update setting label with value from server
$("#settoppcur").html(msg.data);
} else if(msg.cmd == "setlabeltopk") {
// Update setting label with value from server
$("#settopkcur").html(msg.data);
} else if(msg.cmd == "setlabeltfs") {
// Update setting label with value from server
$("#settfscur").html(msg.data);
} else if(msg.cmd == "setlabelreppen") {
// Update setting label with value from server
$("#setreppencur").html(msg.data);