From 72bfc417da2451d10c7970d262c2a4cdf0315e8e Mon Sep 17 00:00:00 2001 From: henk717 Date: Thu, 19 Aug 2021 14:47:57 +0200 Subject: [PATCH] top_k and tfs support by Frogging101 Adds top_k and tfs support, also fixes a SocketIO error. --- aiserver.py | 28 ++++++++++++++++++++++++- gensettings.py | 48 +++++++++++++++++++++++++++++++++++++++++-- static/application.js | 17 +++++++++++++-- 3 files changed, 88 insertions(+), 5 deletions(-) diff --git a/aiserver.py b/aiserver.py index 8db499a1..bb6965b8 100644 --- a/aiserver.py +++ b/aiserver.py @@ -64,6 +64,8 @@ class vars: rep_pen = 1.1 # Default generator repetition_penalty temp = 0.5 # Default generator temperature top_p = 0.9 # Default generator top_p + top_k = 0 # Default generator top_k + tfs = 0.0 # Default generator tfs (tail-free sampling) numseqs = 1 # Number of sequences to ask the generator to create gamestarted = False # Whether the game has started (disables UI elements) prompt = "" # Prompt @@ -449,6 +451,14 @@ def get_message(msg): vars.top_p = float(msg['data']) emit('from_server', {'cmd': 'setlabeltopp', 'data': msg['data']}) settingschanged() + elif(msg['cmd'] == 'settopk'): + vars.top_k = int(msg['data']) + emit('from_server', {'cmd': 'setlabeltopk', 'data': msg['data']}) + settingschanged() + elif(msg['cmd'] == 'settfs'): + vars.tfs = float(msg['data']) + emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']}) + settingschanged() elif(msg['cmd'] == 'setreppen'): vars.rep_pen = float(msg['data']) emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}) @@ -588,6 +598,8 @@ def savesettings(): js["andepth"] = vars.andepth js["temp"] = vars.temp js["top_p"] = vars.top_p + js["top_k"] = vars.top_k + js["tfs"] = vars.tfs js["rep_pen"] = vars.rep_pen js["genamt"] = vars.genamt js["max_length"] = vars.max_length @@ -623,6 +635,10 @@ def loadsettings(): vars.temp = js["temp"] if("top_p" in js): vars.top_p = js["top_p"] + if("top_k" in js): + vars.top_k = js["top_k"] + if("tfs" in js): + vars.tfs = js["tfs"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] if("genamt" in js): @@ -922,13 +938,19 @@ def generate(txt, min, max): # Submit input text to generator try: + top_p = vars.top_p if vars.top_p > 0.0 else None + top_k = vars.top_k if vars.top_k > 0 else None + tfs = vars.tfs if vars.tfs > 0.0 else None + genout = generator( txt, do_sample=True, min_length=min, max_length=max, repetition_penalty=vars.rep_pen, - top_p=vars.top_p, + top_p=top_p, + top_k=top_k, + tfs=tfs, temperature=vars.temp, bad_words_ids=vars.badwordsids, use_cache=True, @@ -1016,6 +1038,8 @@ def sendtocolab(txt, min, max): 'rep_pen': vars.rep_pen, 'temperature': vars.temp, 'top_p': vars.top_p, + 'top_k': vars.top_k, + 'tfs': vars.tfs, 'numseqs': vars.numseqs, 'retfultxt': False } @@ -1139,6 +1163,8 @@ def refresh_settings(): if(vars.model != "InferKit"): emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp}) emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p}) + emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}) + emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}) emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}) emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt}) emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length}) diff --git a/gensettings.py b/gensettings.py index da3aaa76..ac531ba2 100644 --- a/gensettings.py +++ b/gensettings.py @@ -14,11 +14,33 @@ gensettingstf = [{ "unit": "float", "label": "Top p Sampling", "id": "settopp", - "min": 0.1, + "min": 0.0, "max": 1.0, "step": 0.05, "default": 0.9, "tooltip": "Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious." + }, + { + "uitype": "slider", + "unit": "int", + "label": "Top k Sampling", + "id": "settopk", + "min": 0, + "max": 100, + "step": 1, + "default": 0, + "tooltip": "Alternative sampling method, can be combined with top_p." + }, + { + "uitype": "slider", + "unit": "float", + "label": "Tail-free Sampling", + "id": "settfs", + "min": 0.0, + "max": 1.0, + "step": 0.05, + "default": 0.0, + "tooltip": "Alternative sampling method; it is recommended to disable (set to 0) top_p and top_k if using this. 0.95 is thought to be a good value." }, { "uitype": "slider", @@ -114,12 +136,34 @@ gensettingsik =[{ "unit": "float", "label": "Top p Sampling", "id": "settopp", - "min": 0.1, + "min": 0.0, "max": 1.0, "step": 0.05, "default": 1.1, "tooltip": "Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious." }, + { + "uitype": "slider", + "unit": "int", + "label": "Top k Sampling", + "id": "settopk", + "min": 0, + "max": 100, + "step": 1, + "default": 0, + "tooltip": "Alternative sampling method, can be combined with top_p." + }, + { + "uitype": "slider", + "unit": "float", + "label": "Tail-free Sampling", + "id": "settfs", + "min": 0.0, + "max": 1.0, + "step": 0.05, + "default": 0.0, + "tooltip": "Alternative sampling method; it is recommended to disable (set to 0) top_p and top_k if using this. 0.95 is thought to be a good value." + }, { "uitype": "slider", "unit": "int", diff --git a/static/application.js b/static/application.js index 2b9bd2db..4de2409d 100644 --- a/static/application.js +++ b/static/application.js @@ -706,8 +706,7 @@ $(document).ready(function(){ seqselcontents = $("#seqselcontents"); // Connect to SocketIO server - loc = window.document.location; - socket = io.connect(loc.href); + socket = io.connect(window.document.origin); socket.on('from_server', function(msg) { if(msg.cmd == "connected") { @@ -779,6 +778,14 @@ $(document).ready(function(){ // Send current top p value to input $("#settopp").val(parseFloat(msg.data)); $("#settoppcur").html(msg.data); + } else if(msg.cmd == "updatetopk") { + // Send current top k value to input + $("#settopk").val(parseFloat(msg.data)); + $("#settopkcur").html(msg.data); + } else if(msg.cmd == "updatetfs") { + // Send current tfs value to input + $("#settfs").val(parseFloat(msg.data)); + $("#settfscur").html(msg.data); } else if(msg.cmd == "updatereppen") { // Send current rep pen value to input $("#setreppen").val(parseFloat(msg.data)); @@ -801,6 +808,12 @@ $(document).ready(function(){ } else if(msg.cmd == "setlabeltopp") { // Update setting label with value from server $("#settoppcur").html(msg.data); + } else if(msg.cmd == "setlabeltopk") { + // Update setting label with value from server + $("#settopkcur").html(msg.data); + } else if(msg.cmd == "setlabeltfs") { + // Update setting label with value from server + $("#settfscur").html(msg.data); } else if(msg.cmd == "setlabelreppen") { // Update setting label with value from server $("#setreppencur").html(msg.data);