top_k and tfs support by Frogging101

Adds top_k and tfs support, also fixes a SocketIO error.
This commit is contained in:
henk717
2021-08-19 14:47:57 +02:00
parent caee12eae0
commit 72bfc417da
3 changed files with 88 additions and 5 deletions

View File

@ -64,6 +64,8 @@ class vars:
rep_pen = 1.1 # Default generator repetition_penalty
temp = 0.5 # Default generator temperature
top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k
tfs = 0.0 # Default generator tfs (tail-free sampling)
numseqs = 1 # Number of sequences to ask the generator to create
gamestarted = False # Whether the game has started (disables UI elements)
prompt = "" # Prompt
@ -449,6 +451,14 @@ def get_message(msg):
vars.top_p = float(msg['data'])
emit('from_server', {'cmd': 'setlabeltopp', 'data': msg['data']})
settingschanged()
elif(msg['cmd'] == 'settopk'):
vars.top_k = int(msg['data'])
emit('from_server', {'cmd': 'setlabeltopk', 'data': msg['data']})
settingschanged()
elif(msg['cmd'] == 'settfs'):
vars.tfs = float(msg['data'])
emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']})
settingschanged()
elif(msg['cmd'] == 'setreppen'):
vars.rep_pen = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']})
@ -588,6 +598,8 @@ def savesettings():
js["andepth"] = vars.andepth
js["temp"] = vars.temp
js["top_p"] = vars.top_p
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["rep_pen"] = vars.rep_pen
js["genamt"] = vars.genamt
js["max_length"] = vars.max_length
@ -623,6 +635,10 @@ def loadsettings():
vars.temp = js["temp"]
if("top_p" in js):
vars.top_p = js["top_p"]
if("top_k" in js):
vars.top_k = js["top_k"]
if("tfs" in js):
vars.tfs = js["tfs"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("genamt" in js):
@ -922,13 +938,19 @@ def generate(txt, min, max):
# Submit input text to generator
try:
top_p = vars.top_p if vars.top_p > 0.0 else None
top_k = vars.top_k if vars.top_k > 0 else None
tfs = vars.tfs if vars.tfs > 0.0 else None
genout = generator(
txt,
do_sample=True,
min_length=min,
max_length=max,
repetition_penalty=vars.rep_pen,
top_p=vars.top_p,
top_p=top_p,
top_k=top_k,
tfs=tfs,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
@ -1016,6 +1038,8 @@ def sendtocolab(txt, min, max):
'rep_pen': vars.rep_pen,
'temperature': vars.temp,
'top_p': vars.top_p,
'top_k': vars.top_k,
'tfs': vars.tfs,
'numseqs': vars.numseqs,
'retfultxt': False
}
@ -1139,6 +1163,8 @@ def refresh_settings():
if(vars.model != "InferKit"):
emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp})
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p})
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k})
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs})
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen})
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt})
emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length})