top_k and tfs support by Frogging101
Adds top_k and tfs support, also fixes a SocketIO error.
This commit is contained in:
parent
caee12eae0
commit
72bfc417da
28
aiserver.py
28
aiserver.py
|
@ -64,6 +64,8 @@ class vars:
|
|||
rep_pen = 1.1 # Default generator repetition_penalty
|
||||
temp = 0.5 # Default generator temperature
|
||||
top_p = 0.9 # Default generator top_p
|
||||
top_k = 0 # Default generator top_k
|
||||
tfs = 0.0 # Default generator tfs (tail-free sampling)
|
||||
numseqs = 1 # Number of sequences to ask the generator to create
|
||||
gamestarted = False # Whether the game has started (disables UI elements)
|
||||
prompt = "" # Prompt
|
||||
|
@ -449,6 +451,14 @@ def get_message(msg):
|
|||
vars.top_p = float(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabeltopp', 'data': msg['data']})
|
||||
settingschanged()
|
||||
elif(msg['cmd'] == 'settopk'):
|
||||
vars.top_k = int(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabeltopk', 'data': msg['data']})
|
||||
settingschanged()
|
||||
elif(msg['cmd'] == 'settfs'):
|
||||
vars.tfs = float(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']})
|
||||
settingschanged()
|
||||
elif(msg['cmd'] == 'setreppen'):
|
||||
vars.rep_pen = float(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']})
|
||||
|
@ -588,6 +598,8 @@ def savesettings():
|
|||
js["andepth"] = vars.andepth
|
||||
js["temp"] = vars.temp
|
||||
js["top_p"] = vars.top_p
|
||||
js["top_k"] = vars.top_k
|
||||
js["tfs"] = vars.tfs
|
||||
js["rep_pen"] = vars.rep_pen
|
||||
js["genamt"] = vars.genamt
|
||||
js["max_length"] = vars.max_length
|
||||
|
@ -623,6 +635,10 @@ def loadsettings():
|
|||
vars.temp = js["temp"]
|
||||
if("top_p" in js):
|
||||
vars.top_p = js["top_p"]
|
||||
if("top_k" in js):
|
||||
vars.top_k = js["top_k"]
|
||||
if("tfs" in js):
|
||||
vars.tfs = js["tfs"]
|
||||
if("rep_pen" in js):
|
||||
vars.rep_pen = js["rep_pen"]
|
||||
if("genamt" in js):
|
||||
|
@ -922,13 +938,19 @@ def generate(txt, min, max):
|
|||
|
||||
# Submit input text to generator
|
||||
try:
|
||||
top_p = vars.top_p if vars.top_p > 0.0 else None
|
||||
top_k = vars.top_k if vars.top_k > 0 else None
|
||||
tfs = vars.tfs if vars.tfs > 0.0 else None
|
||||
|
||||
genout = generator(
|
||||
txt,
|
||||
do_sample=True,
|
||||
min_length=min,
|
||||
max_length=max,
|
||||
repetition_penalty=vars.rep_pen,
|
||||
top_p=vars.top_p,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
tfs=tfs,
|
||||
temperature=vars.temp,
|
||||
bad_words_ids=vars.badwordsids,
|
||||
use_cache=True,
|
||||
|
@ -1016,6 +1038,8 @@ def sendtocolab(txt, min, max):
|
|||
'rep_pen': vars.rep_pen,
|
||||
'temperature': vars.temp,
|
||||
'top_p': vars.top_p,
|
||||
'top_k': vars.top_k,
|
||||
'tfs': vars.tfs,
|
||||
'numseqs': vars.numseqs,
|
||||
'retfultxt': False
|
||||
}
|
||||
|
@ -1139,6 +1163,8 @@ def refresh_settings():
|
|||
if(vars.model != "InferKit"):
|
||||
emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp})
|
||||
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p})
|
||||
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k})
|
||||
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs})
|
||||
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen})
|
||||
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt})
|
||||
emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length})
|
||||
|
|
|
@ -14,11 +14,33 @@ gensettingstf = [{
|
|||
"unit": "float",
|
||||
"label": "Top p Sampling",
|
||||
"id": "settopp",
|
||||
"min": 0.1,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.05,
|
||||
"default": 0.9,
|
||||
"tooltip": "Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "int",
|
||||
"label": "Top k Sampling",
|
||||
"id": "settopk",
|
||||
"min": 0,
|
||||
"max": 100,
|
||||
"step": 1,
|
||||
"default": 0,
|
||||
"tooltip": "Alternative sampling method, can be combined with top_p."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "float",
|
||||
"label": "Tail-free Sampling",
|
||||
"id": "settfs",
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.05,
|
||||
"default": 0.0,
|
||||
"tooltip": "Alternative sampling method; it is recommended to disable (set to 0) top_p and top_k if using this. 0.95 is thought to be a good value."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
|
@ -114,12 +136,34 @@ gensettingsik =[{
|
|||
"unit": "float",
|
||||
"label": "Top p Sampling",
|
||||
"id": "settopp",
|
||||
"min": 0.1,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.05,
|
||||
"default": 1.1,
|
||||
"tooltip": "Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "int",
|
||||
"label": "Top k Sampling",
|
||||
"id": "settopk",
|
||||
"min": 0,
|
||||
"max": 100,
|
||||
"step": 1,
|
||||
"default": 0,
|
||||
"tooltip": "Alternative sampling method, can be combined with top_p."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "float",
|
||||
"label": "Tail-free Sampling",
|
||||
"id": "settfs",
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.05,
|
||||
"default": 0.0,
|
||||
"tooltip": "Alternative sampling method; it is recommended to disable (set to 0) top_p and top_k if using this. 0.95 is thought to be a good value."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "int",
|
||||
|
|
|
@ -706,8 +706,7 @@ $(document).ready(function(){
|
|||
seqselcontents = $("#seqselcontents");
|
||||
|
||||
// Connect to SocketIO server
|
||||
loc = window.document.location;
|
||||
socket = io.connect(loc.href);
|
||||
socket = io.connect(window.document.origin);
|
||||
|
||||
socket.on('from_server', function(msg) {
|
||||
if(msg.cmd == "connected") {
|
||||
|
@ -779,6 +778,14 @@ $(document).ready(function(){
|
|||
// Send current top p value to input
|
||||
$("#settopp").val(parseFloat(msg.data));
|
||||
$("#settoppcur").html(msg.data);
|
||||
} else if(msg.cmd == "updatetopk") {
|
||||
// Send current top k value to input
|
||||
$("#settopk").val(parseFloat(msg.data));
|
||||
$("#settopkcur").html(msg.data);
|
||||
} else if(msg.cmd == "updatetfs") {
|
||||
// Send current tfs value to input
|
||||
$("#settfs").val(parseFloat(msg.data));
|
||||
$("#settfscur").html(msg.data);
|
||||
} else if(msg.cmd == "updatereppen") {
|
||||
// Send current rep pen value to input
|
||||
$("#setreppen").val(parseFloat(msg.data));
|
||||
|
@ -801,6 +808,12 @@ $(document).ready(function(){
|
|||
} else if(msg.cmd == "setlabeltopp") {
|
||||
// Update setting label with value from server
|
||||
$("#settoppcur").html(msg.data);
|
||||
} else if(msg.cmd == "setlabeltopk") {
|
||||
// Update setting label with value from server
|
||||
$("#settopkcur").html(msg.data);
|
||||
} else if(msg.cmd == "setlabeltfs") {
|
||||
// Update setting label with value from server
|
||||
$("#settfscur").html(msg.data);
|
||||
} else if(msg.cmd == "setlabelreppen") {
|
||||
// Update setting label with value from server
|
||||
$("#setreppencur").html(msg.data);
|
||||
|
|
Loading…
Reference in New Issue