Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
top_k and tfs support by Frogging101
Adds top_k and tfs support, also fixes a SocketIO error.
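For readers unfamiliar with the new `tfs` setting: tail-free sampling trims the low-probability "tail" of the token distribution based on the curvature of the sorted probability curve. The sketch below is purely illustrative and is not code from this commit; the helper name `tail_free_filter` and the threshold argument `z` are invented for the example.

    import numpy as np

    def tail_free_filter(probs, z):
        # Sort token probabilities from most to least likely.
        order = np.argsort(probs)[::-1]
        sorted_p = probs[order]
        # Absolute second finite difference of the sorted curve, normalized to sum to 1.
        d2 = np.abs(np.diff(sorted_p, n=2))
        d2 = d2 / max(d2.sum(), 1e-12)
        # Keep tokens until the cumulative curvature mass reaches z; the first two
        # tokens have no second difference and always survive.
        keep = np.concatenate(([True, True], np.cumsum(d2) < z))
        filtered = np.zeros_like(probs)
        filtered[order[keep]] = sorted_p[keep]
        return filtered / filtered.sum()

    # Example: with z close to 1, only the flat end of the tail is dropped before sampling.
    p = np.array([0.40, 0.30, 0.15, 0.05, 0.04, 0.03, 0.02, 0.01])
    print(tail_free_filter(p, z=0.95))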
aiserver.py (28 changed lines)
@@ -64,6 +64,8 @@ class vars:
     rep_pen     = 1.1    # Default generator repetition_penalty
     temp        = 0.5    # Default generator temperature
     top_p       = 0.9    # Default generator top_p
+    top_k       = 0      # Default generator top_k
+    tfs         = 0.0    # Default generator tfs (tail-free sampling)
     numseqs     = 1      # Number of sequences to ask the generator to create
     gamestarted = False  # Whether the game has started (disables UI elements)
     prompt      = ""     # Prompt
@@ -449,6 +451,14 @@ def get_message(msg):
         vars.top_p = float(msg['data'])
         emit('from_server', {'cmd': 'setlabeltopp', 'data': msg['data']})
         settingschanged()
+    elif(msg['cmd'] == 'settopk'):
+        vars.top_k = int(msg['data'])
+        emit('from_server', {'cmd': 'setlabeltopk', 'data': msg['data']})
+        settingschanged()
+    elif(msg['cmd'] == 'settfs'):
+        vars.tfs = float(msg['data'])
+        emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']})
+        settingschanged()
     elif(msg['cmd'] == 'setreppen'):
         vars.rep_pen = float(msg['data'])
         emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']})
@@ -588,6 +598,8 @@ def savesettings():
     js["andepth"]    = vars.andepth
     js["temp"]       = vars.temp
     js["top_p"]      = vars.top_p
+    js["top_k"]      = vars.top_k
+    js["tfs"]        = vars.tfs
     js["rep_pen"]    = vars.rep_pen
     js["genamt"]     = vars.genamt
     js["max_length"] = vars.max_length
@@ -623,6 +635,10 @@ def loadsettings():
         vars.temp = js["temp"]
     if("top_p" in js):
         vars.top_p = js["top_p"]
+    if("top_k" in js):
+        vars.top_k = js["top_k"]
+    if("tfs" in js):
+        vars.tfs = js["tfs"]
     if("rep_pen" in js):
         vars.rep_pen = js["rep_pen"]
     if("genamt" in js):
@@ -922,13 +938,19 @@ def generate(txt, min, max):
 
     # Submit input text to generator
     try:
+        top_p = vars.top_p if vars.top_p > 0.0 else None
+        top_k = vars.top_k if vars.top_k > 0 else None
+        tfs = vars.tfs if vars.tfs > 0.0 else None
+
         genout = generator(
             txt,
             do_sample=True,
             min_length=min,
             max_length=max,
             repetition_penalty=vars.rep_pen,
-            top_p=vars.top_p,
+            top_p=top_p,
+            top_k=top_k,
+            tfs=tfs,
             temperature=vars.temp,
             bad_words_ids=vars.badwordsids,
             use_cache=True,
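The `x if x > 0 else None` guards added above let a zero slider value disable a sampler entirely instead of passing 0 to the generator (whether `tfs` is accepted as a keyword depends on the generator build in use). A minimal standalone illustration of the pattern; the helper name `sampler_kwargs` is hypothetical and not part of aiserver.py:

    def sampler_kwargs(top_p, top_k, tfs):
        # Treat 0 / 0.0 as "off" and pass None so that sampler is skipped entirely.
        return {
            "top_p": top_p if top_p > 0.0 else None,
            "top_k": top_k if top_k > 0 else None,
            "tfs": tfs if tfs > 0.0 else None,
        }

    print(sampler_kwargs(0.9, 0, 0.0))   # {'top_p': 0.9, 'top_k': None, 'tfs': None}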
@@ -1016,6 +1038,8 @@ def sendtocolab(txt, min, max):
         'rep_pen': vars.rep_pen,
         'temperature': vars.temp,
         'top_p': vars.top_p,
+        'top_k': vars.top_k,
+        'tfs': vars.tfs,
         'numseqs': vars.numseqs,
         'retfultxt': False
     }
@@ -1139,6 +1163,8 @@ def refresh_settings():
     if(vars.model != "InferKit"):
         emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp})
         emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p})
+        emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k})
+        emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs})
         emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen})
         emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt})
         emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length})
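To exercise the new `settopk` and `settfs` commands end to end, a SocketIO client can send the same payloads the browser UI does. The sketch below is not part of the commit and rests on assumptions: that the server is reachable at http://localhost:5000 and that get_message() is wired to the default 'message' event.

    import socketio  # pip install "python-socketio[client]"

    sio = socketio.Client()
    sio.connect("http://localhost:5000")

    # Mirror what the UI sliders send: set top_k to 40 and tfs to 0.95.
    sio.emit("message", {"cmd": "settopk", "data": 40})
    sio.emit("message", {"cmd": "settfs", "data": 0.95})
    sio.disconnect()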