diff --git a/aiserver.py b/aiserver.py index b1da8d95..1e0cc1a6 100644 --- a/aiserver.py +++ b/aiserver.py @@ -675,6 +675,7 @@ def savesettings(): js["autosave"] = koboldai_vars.autosave js["welcome"] = koboldai_vars.welcome js["newlinemode"] = koboldai_vars.newlinemode + js["output_streaming"] = koboldai_vars.output_streaming js["antemplate"] = koboldai_vars.setauthornotetemplate @@ -779,6 +780,8 @@ def processsettings(js): koboldai_vars.newlinemode = js["newlinemode"] if("welcome" in js): koboldai_vars.welcome = js["welcome"] + if("output_streaming" in js): + koboldai_vars.output_streaming = js["output_streaming"] if("antemplate" in js): koboldai_vars.setauthornotetemplate = js["antemplate"] @@ -1412,6 +1415,33 @@ def patch_transformers(): new_init.old_init = transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ = new_init + class TokenStreamer(StoppingCriteria): + # A StoppingCriteria is used here because it seems to run after + # everything has been evaluated score-wise. 
+ def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + **kwargs, + ) -> bool: + if (not koboldai_vars.output_streaming): + return False + + #for batch, ids in enumerate(input_ids): + #tokenizer_text = utils.decodenewlines(tokenizer.decode(ids[-1])) + #koboldai_vars.actions.stream_token(tokenizer_text, batch=batch) + + koboldai_vars.actions.stream_tokens([utils.decodenewlines(tokenizer.decode(x[-1])) for x in input_ids]) + #if len(input_ids) > 1: + # koboldai_vars.actions.clear_unused_options() + # koboldai_vars.actions.append_options([utils.decodenewlines(tokenizer.decode(x[-1])) for x in input_ids]) + #else: + # koboldai_vars.actions[koboldai_vars.actions.action_count+1] = utils.decodenewlines(tokenizer.decode(input_ids[0, -1])) + + return False # Sets up dynamic world info scanner class DynamicWorldInfoScanCriteria(StoppingCriteria): @@ -1467,6 +1497,8 @@ def patch_transformers(): excluded_world_info=self.kai_scanner_excluded_world_info, ) stopping_criteria.insert(0, self.kai_scanner) + token_streamer = TokenStreamer(tokenizer=tokenizer) + stopping_criteria.insert(0, token_streamer) return stopping_criteria transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria @@ -2616,6 +2648,7 @@ def lua_has_setting(setting): "rmspch", "adsnsp", "singleline", + "output_streaming", ) #==================================================================# @@ -2647,6 +2680,7 @@ def lua_get_setting(setting): if(setting in ("frmtrmspch", "rmspch")): return koboldai_vars.formatoptns["frmttrmspch"] if(setting in ("frmtadsnsp", "adsnsp")): return koboldai_vars.formatoptns["frmtadsnsp"] if(setting in ("frmtsingleline", "singleline")): return koboldai_vars.formatoptns["singleline"] + if(setting == "output_streaming"): return koboldai_vars.output_streaming #==================================================================# # Set the setting with the given 
name if it exists @@ -3391,6 +3425,10 @@ def get_message(msg): koboldai_vars.nogenmod = msg['data'] settingschanged() refresh_settings() + elif(msg['cmd'] == 'setoutputstreaming'): + koboldai_vars.output_streaming = msg['data'] + settingschanged() + refresh_settings() elif(not koboldai_vars.host and msg['cmd'] == 'importwi'): wiimportrequest() elif(msg['cmd'] == 'debug'): @@ -4513,6 +4551,7 @@ def refresh_settings(): emit('from_server', {'cmd': 'updatefrmtrmspch', 'data': koboldai_vars.formatoptns["frmtrmspch"]}, broadcast=True, room="UI_1") emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': koboldai_vars.formatoptns["frmtadsnsp"]}, broadcast=True, room="UI_1") emit('from_server', {'cmd': 'updatesingleline', 'data': koboldai_vars.formatoptns["singleline"]}, broadcast=True, room="UI_1") + emit('from_server', {'cmd': 'updateoutputstreaming', 'data': koboldai_vars.output_streaming}, broadcast=True, room="UI_1") # Allow toggle events again emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True, room="UI_1") diff --git a/gensettings.py b/gensettings.py index 94fa312d..16a81950 100644 --- a/gensettings.py +++ b/gensettings.py @@ -331,7 +331,21 @@ gensettingstf = [ "classname": "story", "name": "actionmode", 'children': [{'text': 'Story', 'value': 0}, {'text':'Adventure','value':1}, {'text':'Chat', 'value':2}] - } + }, + { + "uitype": "toggle", + "unit": "bool", + "label": "Token Streaming", + "id": "setoutputstreaming", + "min": 0, + "max": 1, + "step": 1, + "default": 0, + "tooltip": "Shows outputs to you as they are made.", + "menu_path": "User", + "classname": "user", + "name": "output_streaming" + } ] gensettingsik =[{ @@ -520,7 +534,21 @@ gensettingsik =[{ "menu_path": "User", "classname": "user", "name": "debug" - } + }, + { + "uitype": "toggle", + "unit": "bool", + "label": "Token Streaming", + "id": "setoutputstreaming", + "min": 0, + "max": 1, + "step": 1, + "default": 0, + "tooltip": "Shows outputs to you as they are made.", + 
"menu_path": "User", + "classname": "user", + "name": "output_streaming" + } ] formatcontrols = [{ diff --git a/koboldai_settings.py b/koboldai_settings.py index 1d386858..8287647d 100644 --- a/koboldai_settings.py +++ b/koboldai_settings.py @@ -444,6 +444,7 @@ class user_settings(settings): self.rngpersist = False self.nogenmod = False self.debug = False # If set to true, will send debug information to the client for display + self.output_streaming = True def __setattr__(self, name, value): @@ -563,6 +564,7 @@ class KoboldStoryRegister(object): self.actions[i]["Options"].append({"text": old_text, "Pinned": False, "Previous Selection": False, "Edited": True}) else: old_text = None + old_length = None self.actions[i] = {"Selected Text": text, "Options": []} if self.tokenizer is not None: @@ -730,9 +732,6 @@ class KoboldStoryRegister(object): text = self.actions[self.action_count]['Selected Text'] length = self.actions[self.action_count]['Selected Text Length'] self.delete_action(self.action_count) - process_variable_changes(self.socketio, "actions", "Selected Text", {"id": self.action_count, "text": None}, {"id": self.action_count, "text": text}) - process_variable_changes(self.socketio, "actions", 'Selected Text Length', {"id": self.action_count, 'length': 0}, {"id": self.action_count, 'length': length}) - self.set_game_saved() return text else: return None @@ -777,6 +776,33 @@ class KoboldStoryRegister(object): for key in self.actions: self.actions[key]['Selected Text Length'] = None + def stream_tokens(self, text_list): + if len(text_list) > 1: + if self.action_count+1 in self.actions: + for i in range(len(text_list)): + for j in range(len(self.actions[self.action_count+1]['Options'])): + if 'stream_id' in self.actions[self.action_count+1]['Options'][j]: + if self.actions[self.action_count+1]['Options'][j]['stream_id'] == i: + self.actions[self.action_count+1]['Options'][i]['text'] = "{}{}".format(self.actions[self.action_count+1]['Options'][i]['text'], 
text_list[i]) + else: + self.actions[self.action_count+1] = {"Selected Text": "", "Selected Text Length": 0, "Options": []} + for i in range(len(text_list)): + self.actions[self.action_count+1]['Options'].append({"text": text_list[i], "Pinned": False, "Previous Selection": False, "Edited": False, "stream_id": i}) + + process_variable_changes(self.socketio, "actions", "Options", {"id": self.action_count+1, "options": self.actions[self.action_count+1]["Options"]}, {"id": self.action_count+1, "options": None}) + else: + #We're streaming single options so our output is our selected + if self.tokenizer is not None: + selected_text_length = len(self.tokenizer.encode(text_list[0])) + else: + selected_text_length = 0 + if self.action_count+1 in self.actions: + self.actions[self.action_count+1]['Selected Text'] = "{}{}".format(self.actions[self.action_count+1]['Selected Text'], text_list[0]) + else: + self.actions[self.action_count+1] = {"Selected Text": text_list[0], "Selected Text Length": selected_text_length, "Options": []} + + process_variable_changes(self.socketio, "actions", "Selected Text", {"id": self.action_count+1, "text": self.actions[self.action_count+1]['Selected Text']}, None) + process_variable_changes(self.socketio, "actions", 'Selected Text Length', {"id": self.action_count+1, 'length': self.actions[self.action_count+1]['Selected Text Length']}, {"id": self.action_count, 'length': 0}) def __setattr__(self, name, value): new_variable = name not in self.__dict__ diff --git a/static/koboldai.css b/static/koboldai.css index b8628049..5ed127b1 100644 --- a/static/koboldai.css +++ b/static/koboldai.css @@ -591,6 +591,7 @@ body { .sequence { border: 1px solid #959595; border-radius: 5px; + width: 100%; grid-area: text; padding: 0px; background-color: var(--options_background); diff --git a/static/koboldai.js b/static/koboldai.js index a83f004d..6a2cd445 100644 --- a/static/koboldai.js +++ b/static/koboldai.js @@ -25,7 +25,7 @@ socket.on("delete_world_info_entry", 
function(data){document.getElementById("wor //socket.onAny(function(event_name, data) {console.log({"event": event_name, "class": data.classname, "data": data});}); var backend_vars = {}; -var presets = {} +var presets = {}; var current_chunk_number = null; var ai_busy_start = Date.now(); var popup_deleteable = false;