Implemented token streaming from one-some

This commit is contained in:
ebolam
2022-08-05 12:35:38 -04:00
parent 3285926164
commit 2ee4371622
5 changed files with 100 additions and 6 deletions

View File

@@ -675,6 +675,7 @@ def savesettings():
js["autosave"] = koboldai_vars.autosave js["autosave"] = koboldai_vars.autosave
js["welcome"] = koboldai_vars.welcome js["welcome"] = koboldai_vars.welcome
js["newlinemode"] = koboldai_vars.newlinemode js["newlinemode"] = koboldai_vars.newlinemode
js["output_streaming"] = koboldai_vars.output_streaming
js["antemplate"] = koboldai_vars.setauthornotetemplate js["antemplate"] = koboldai_vars.setauthornotetemplate
@@ -779,6 +780,8 @@ def processsettings(js):
koboldai_vars.newlinemode = js["newlinemode"] koboldai_vars.newlinemode = js["newlinemode"]
if("welcome" in js): if("welcome" in js):
koboldai_vars.welcome = js["welcome"] koboldai_vars.welcome = js["welcome"]
if("output_streaming" in js):
koboldai_vars.output_streaming = js["output_streaming"]
if("antemplate" in js): if("antemplate" in js):
koboldai_vars.setauthornotetemplate = js["antemplate"] koboldai_vars.setauthornotetemplate = js["antemplate"]
@@ -1412,6 +1415,33 @@ def patch_transformers():
new_init.old_init = transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ new_init.old_init = transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__
transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ = new_init transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
class TokenStreamer(StoppingCriteria):
    """Streams each newly generated token to the UI as it is produced.

    Implemented as a StoppingCriteria (that never stops generation)
    because stopping criteria run after scores have been evaluated on
    every generation step, making them a convenient per-token hook.
    """

    def __init__(self, tokenizer):
        # Tokenizer used to decode newly generated token ids into text.
        self.tokenizer = tokenizer

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        # Streaming can be toggled off at runtime; bail out cheaply.
        if not koboldai_vars.output_streaming:
            return False
        # Decode only the newest token of each sequence in the batch and
        # push the batch of decoded strings to the story register.
        # Fix: use the tokenizer handed to the constructor rather than the
        # module-level `tokenizer` global, so the streamer keeps working
        # if the global is rebound (e.g. after a model reload).
        koboldai_vars.actions.stream_tokens(
            [utils.decodenewlines(self.tokenizer.decode(ids[-1])) for ids in input_ids]
        )
        # Never request a stop — this criteria exists only for its side effect.
        return False
# Sets up dynamic world info scanner # Sets up dynamic world info scanner
class DynamicWorldInfoScanCriteria(StoppingCriteria): class DynamicWorldInfoScanCriteria(StoppingCriteria):
@@ -1467,6 +1497,8 @@ def patch_transformers():
excluded_world_info=self.kai_scanner_excluded_world_info, excluded_world_info=self.kai_scanner_excluded_world_info,
) )
stopping_criteria.insert(0, self.kai_scanner) stopping_criteria.insert(0, self.kai_scanner)
token_streamer = TokenStreamer(tokenizer=tokenizer)
stopping_criteria.insert(0, token_streamer)
return stopping_criteria return stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
@@ -2616,6 +2648,7 @@ def lua_has_setting(setting):
"rmspch", "rmspch",
"adsnsp", "adsnsp",
"singleline", "singleline",
"output_streaming",
) )
#==================================================================# #==================================================================#
@@ -2647,6 +2680,7 @@ def lua_get_setting(setting):
if(setting in ("frmtrmspch", "rmspch")): return koboldai_vars.formatoptns["frmttrmspch"] if(setting in ("frmtrmspch", "rmspch")): return koboldai_vars.formatoptns["frmttrmspch"]
if(setting in ("frmtadsnsp", "adsnsp")): return koboldai_vars.formatoptns["frmtadsnsp"] if(setting in ("frmtadsnsp", "adsnsp")): return koboldai_vars.formatoptns["frmtadsnsp"]
if(setting in ("frmtsingleline", "singleline")): return koboldai_vars.formatoptns["singleline"] if(setting in ("frmtsingleline", "singleline")): return koboldai_vars.formatoptns["singleline"]
if(setting == "output_streaming"): return koboldai_vars.output_streaming
#==================================================================# #==================================================================#
# Set the setting with the given name if it exists # Set the setting with the given name if it exists
@@ -3391,6 +3425,10 @@ def get_message(msg):
koboldai_vars.nogenmod = msg['data'] koboldai_vars.nogenmod = msg['data']
settingschanged() settingschanged()
refresh_settings() refresh_settings()
elif(msg['cmd'] == 'setoutputstreaming'):
koboldai_vars.output_streaming = msg['data']
settingschanged()
refresh_settings()
elif(not koboldai_vars.host and msg['cmd'] == 'importwi'): elif(not koboldai_vars.host and msg['cmd'] == 'importwi'):
wiimportrequest() wiimportrequest()
elif(msg['cmd'] == 'debug'): elif(msg['cmd'] == 'debug'):
@@ -4513,6 +4551,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatefrmtrmspch', 'data': koboldai_vars.formatoptns["frmtrmspch"]}, broadcast=True, room="UI_1") emit('from_server', {'cmd': 'updatefrmtrmspch', 'data': koboldai_vars.formatoptns["frmtrmspch"]}, broadcast=True, room="UI_1")
emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': koboldai_vars.formatoptns["frmtadsnsp"]}, broadcast=True, room="UI_1") emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': koboldai_vars.formatoptns["frmtadsnsp"]}, broadcast=True, room="UI_1")
emit('from_server', {'cmd': 'updatesingleline', 'data': koboldai_vars.formatoptns["singleline"]}, broadcast=True, room="UI_1") emit('from_server', {'cmd': 'updatesingleline', 'data': koboldai_vars.formatoptns["singleline"]}, broadcast=True, room="UI_1")
emit('from_server', {'cmd': 'updateoutputstreaming', 'data': koboldai_vars.output_streaming}, broadcast=True, room="UI_1")
# Allow toggle events again # Allow toggle events again
emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True, room="UI_1") emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True, room="UI_1")

View File

@@ -331,7 +331,21 @@ gensettingstf = [
"classname": "story", "classname": "story",
"name": "actionmode", "name": "actionmode",
'children': [{'text': 'Story', 'value': 0}, {'text':'Adventure','value':1}, {'text':'Chat', 'value':2}] 'children': [{'text': 'Story', 'value': 0}, {'text':'Adventure','value':1}, {'text':'Chat', 'value':2}]
} },
{
"uitype": "toggle",
"unit": "bool",
"label": "Token Streaming",
"id": "setoutputstreaming",
"min": 0,
"max": 1,
"step": 1,
"default": 0,
"tooltip": "Shows outputs to you as they are made.",
"menu_path": "User",
"classname": "user",
"name": "output_streaming"
}
] ]
gensettingsik =[{ gensettingsik =[{
@@ -520,7 +534,21 @@ gensettingsik =[{
"menu_path": "User", "menu_path": "User",
"classname": "user", "classname": "user",
"name": "debug" "name": "debug"
} },
{
"uitype": "toggle",
"unit": "bool",
"label": "Token Streaming",
"id": "setoutputstreaming",
"min": 0,
"max": 1,
"step": 1,
"default": 0,
"tooltip": "Shows outputs to you as they are made.",
"menu_path": "User",
"classname": "user",
"name": "output_streaming"
}
] ]
formatcontrols = [{ formatcontrols = [{

View File

@@ -444,6 +444,7 @@ class user_settings(settings):
self.rngpersist = False self.rngpersist = False
self.nogenmod = False self.nogenmod = False
self.debug = False # If set to true, will send debug information to the client for display self.debug = False # If set to true, will send debug information to the client for display
self.output_streaming = True
def __setattr__(self, name, value): def __setattr__(self, name, value):
@@ -563,6 +564,7 @@ class KoboldStoryRegister(object):
self.actions[i]["Options"].append({"text": old_text, "Pinned": False, "Previous Selection": False, "Edited": True}) self.actions[i]["Options"].append({"text": old_text, "Pinned": False, "Previous Selection": False, "Edited": True})
else: else:
old_text = None old_text = None
old_length = None
self.actions[i] = {"Selected Text": text, "Options": []} self.actions[i] = {"Selected Text": text, "Options": []}
if self.tokenizer is not None: if self.tokenizer is not None:
@@ -730,9 +732,6 @@ class KoboldStoryRegister(object):
text = self.actions[self.action_count]['Selected Text'] text = self.actions[self.action_count]['Selected Text']
length = self.actions[self.action_count]['Selected Text Length'] length = self.actions[self.action_count]['Selected Text Length']
self.delete_action(self.action_count) self.delete_action(self.action_count)
process_variable_changes(self.socketio, "actions", "Selected Text", {"id": self.action_count, "text": None}, {"id": self.action_count, "text": text})
process_variable_changes(self.socketio, "actions", 'Selected Text Length', {"id": self.action_count, 'length': 0}, {"id": self.action_count, 'length': length})
self.set_game_saved()
return text return text
else: else:
return None return None
@@ -777,6 +776,33 @@ class KoboldStoryRegister(object):
for key in self.actions: for key in self.actions:
self.actions[key]['Selected Text Length'] = None self.actions[key]['Selected Text Length'] = None
def stream_tokens(self, text_list):
    """Append freshly streamed token text to the in-progress action.

    text_list: decoded token strings, one entry per batch sequence.
    With multiple sequences each one streams into its own option
    (matched via the 'stream_id' tag); a single sequence streams
    straight into the next action's selected text.
    """
    if len(text_list) > 1:
        if self.action_count + 1 in self.actions:
            options = self.actions[self.action_count + 1]['Options']
            for i in range(len(text_list)):
                for j in range(len(options)):
                    if 'stream_id' in options[j] and options[j]['stream_id'] == i:
                        # Bug fix: the original appended to Options[i]
                        # instead of the matched Options[j], corrupting
                        # text when options are not in stream_id order.
                        options[j]['text'] = "{}{}".format(options[j]['text'], text_list[i])
        else:
            # First streamed step: create the pending action with one
            # tagged option per sequence so later steps can find them.
            self.actions[self.action_count + 1] = {"Selected Text": "", "Selected Text Length": 0, "Options": []}
            for i in range(len(text_list)):
                self.actions[self.action_count + 1]['Options'].append({"text": text_list[i], "Pinned": False, "Previous Selection": False, "Edited": False, "stream_id": i})
        process_variable_changes(self.socketio, "actions", "Options", {"id": self.action_count + 1, "options": self.actions[self.action_count + 1]["Options"]}, {"id": self.action_count + 1, "options": None})
    else:
        # We're streaming a single option so our output is our selected text.
        if self.tokenizer is not None:
            selected_text_length = len(self.tokenizer.encode(text_list[0]))
        else:
            # NOTE(review): length falls back to 0 when no tokenizer is
            # loaded — confirm downstream consumers tolerate that.
            selected_text_length = 0
        if self.action_count + 1 in self.actions:
            self.actions[self.action_count + 1]['Selected Text'] = "{}{}".format(self.actions[self.action_count + 1]['Selected Text'], text_list[0])
        else:
            self.actions[self.action_count + 1] = {"Selected Text": text_list[0], "Selected Text Length": selected_text_length, "Options": []}
        process_variable_changes(self.socketio, "actions", "Selected Text", {"id": self.action_count + 1, "text": self.actions[self.action_count + 1]['Selected Text']}, None)
        process_variable_changes(self.socketio, "actions", 'Selected Text Length', {"id": self.action_count + 1, 'length': self.actions[self.action_count + 1]['Selected Text Length']}, {"id": self.action_count, 'length': 0})
def __setattr__(self, name, value): def __setattr__(self, name, value):
new_variable = name not in self.__dict__ new_variable = name not in self.__dict__

View File

@@ -591,6 +591,7 @@ body {
.sequence { .sequence {
border: 1px solid #959595; border: 1px solid #959595;
border-radius: 5px; border-radius: 5px;
width: 100%;
grid-area: text; grid-area: text;
padding: 0px; padding: 0px;
background-color: var(--options_background); background-color: var(--options_background);

View File

@@ -25,7 +25,7 @@ socket.on("delete_world_info_entry", function(data){document.getElementById("wor
//socket.onAny(function(event_name, data) {console.log({"event": event_name, "class": data.classname, "data": data});}); //socket.onAny(function(event_name, data) {console.log({"event": event_name, "class": data.classname, "data": data});});
var backend_vars = {}; var backend_vars = {};
var presets = {} var presets = {};
var current_chunk_number = null; var current_chunk_number = null;
var ai_busy_start = Date.now(); var ai_busy_start = Date.now();
var popup_deleteable = false; var popup_deleteable = false;