From a4d81292f839f6e38fd95428ef8f14bfb9fe5270 Mon Sep 17 00:00:00 2001
From: somebody
Date: Wed, 27 Jul 2022 22:13:08 -0500
Subject: [PATCH] Add token streaming option

---
 aiserver.py           | 47 ++++++++++++++++++++++++++++++++++++++++++-
 gensettings.py        | 13 +++++++++++-
 static/application.js | 34 +++++++++++++++++++++++++++++++
 3 files changed, 92 insertions(+), 2 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 52cd7b28..06d23f17 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -351,6 +351,8 @@ class vars:
     lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage
     use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != "" # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
     revision = None
+    output_streaming = False
+    token_stream_queue = [] # Queue for the token streaming
 
 utils.vars = vars
 
@@ -800,6 +802,7 @@ def savesettings():
     js["fulldeterminism"] = vars.full_determinism
     js["autosave"] = vars.autosave
     js["welcome"] = vars.welcome
+    js["output_streaming"] = vars.output_streaming
 
     if(vars.seed_specified):
         js["seed"] = vars.seed
@@ -911,6 +914,8 @@ def processsettings(js):
         vars.newlinemode = js["newlinemode"]
     if("welcome" in js):
         vars.welcome = js["welcome"]
+    if("output_streaming" in js):
+        vars.output_streaming = js["output_streaming"]
     if("seed" in js):
         vars.seed = js["seed"]
 
@@ -943,12 +948,20 @@
 def check_for_sp_change():
     while(True):
-        time.sleep(0.1)
+        time.sleep(0.05)
+
         if(vars.sp_changed):
             with app.app_context():
                 emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
             vars.sp_changed = False
 
+        if(vars.output_streaming and vars.token_stream_queue):
+            # If emit blocks, waiting for it to complete before clearing could
+            # introduce a race condition that drops tokens.
+            queued_tokens = list(vars.token_stream_queue)
+            vars.token_stream_queue.clear()
+            socketio.emit("from_server", {"cmd": "streamtoken", "data": queued_tokens}, namespace=None, broadcast=True)
+
 
 socketio.start_background_task(check_for_sp_change)
 
 def spRequest(filename):
@@ -1541,6 +1554,27 @@ def patch_transformers():
 
     new_init.old_init = transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__
     transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
 
+    class TokenStreamer(StoppingCriteria):
+        # A StoppingCriteria is used here because it seems to run after
+        # everything has been evaluated score-wise.
+        def __init__(self, tokenizer):
+            self.tokenizer = tokenizer
+
+        def __call__(
+            self,
+            input_ids: torch.LongTensor,
+            scores: torch.FloatTensor,
+            **kwargs,
+        ) -> bool:
+            # Do not intermingle multiple generations' outputs!
+            if(vars.numseqs > 1):
+                return False
+
+            tokenizer_text = utils.decodenewlines(self.tokenizer.decode(input_ids[0, -1]))
+
+            vars.token_stream_queue.append(tokenizer_text)
+            return False
+
     # Sets up dynamic world info scanner
     class DynamicWorldInfoScanCriteria(StoppingCriteria):
 
@@ -1595,7 +1629,10 @@ def patch_transformers():
             tokenizer=tokenizer,
             excluded_world_info=self.kai_scanner_excluded_world_info,
         )
+        token_streamer = TokenStreamer(tokenizer=tokenizer)
+
         stopping_criteria.insert(0, self.kai_scanner)
+        stopping_criteria.insert(0, token_streamer)
         return stopping_criteria
     transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
 
@@ -2697,6 +2734,7 @@ def lua_has_setting(setting):
         "rmspch",
         "adsnsp",
         "singleline",
+        "output_streaming"
     )
 
 #==================================================================#
@@ -2728,6 +2766,7 @@ def lua_get_setting(setting):
     if(setting in ("frmtrmspch", "rmspch")): return vars.formatoptns["frmttrmspch"]
     if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
     if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]
+    if(setting == "output_streaming"): return vars.output_streaming
 
 #==================================================================#
 # Set the setting with the given name if it exists
@@ -2764,6 +2803,7 @@ def lua_set_setting(setting, v):
     if(setting in ("frmtrmspch", "rmspch")): vars.formatoptns["frmttrmspch"] = v
     if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
     if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v
+    if(setting == "output_streaming"): vars.output_streaming = v
 
 #==================================================================#
 # Get contents of memory
@@ -3476,6 +3516,10 @@ def get_message(msg):
         vars.full_determinism = msg['data']
         settingschanged()
         refresh_settings()
+    elif(msg['cmd'] == 'setoutputstreaming'):
+        vars.output_streaming = msg['data']
+        settingschanged()
+        refresh_settings()
     elif(not vars.host and msg['cmd'] == 'importwi'):
         wiimportrequest()
     elif(msg['cmd'] == 'debug'):
@@ -4673,6 +4717,7 @@ def refresh_settings():
     emit('from_server', {'cmd': 'updatefrmtrmspch', 'data': vars.formatoptns["frmtrmspch"]}, broadcast=True)
     emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': vars.formatoptns["frmtadsnsp"]}, broadcast=True)
     emit('from_server', {'cmd': 'updatesingleline', 'data': vars.formatoptns["singleline"]}, broadcast=True)
+    emit('from_server', {'cmd': 'updateoutputstreaming', 'data': vars.output_streaming}, broadcast=True)
 
     # Allow toggle events again
     emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True)
diff --git a/gensettings.py b/gensettings.py
index 3fdb7669..5ec022ef 100644
--- a/gensettings.py
+++ b/gensettings.py
@@ -251,7 +251,18 @@ gensettingstf = [
         "step": 1,
         "default": 0,
         "tooltip": "Show debug info"
-    }
+    },
+    {
+        "uitype": "toggle",
+        "unit": "bool",
+        "label": "Token Streaming",
+        "id": "setoutputstreaming",
+        "min": 0,
+        "max": 1,
+        "step": 1,
+        "default": 0,
+        "tooltip": "Shows outputs to you as they are generated. Does not work when Gens Per Action is greater than one."
+    },
 ]
 
 gensettingsik =[{
diff --git a/static/application.js b/static/application.js
index 2388aa23..7475fea2 100644
--- a/static/application.js
+++ b/static/application.js
@@ -78,6 +78,7 @@ var rs_accept;
 var rs_close;
 var seqselmenu;
 var seqselcontents;
+var stream_preview;
 
 var storyname = null;
 var memorymode = false;
@@ -103,6 +104,7 @@ var gamestate = "";
 var gamesaved = true;
 var modelname = null;
 var model = "";
+var ignore_stream = false;
 
 // This is true iff [we're in macOS and the browser is Safari] or [we're in iOS]
 var using_webkit_patch = true;
@@ -888,6 +890,7 @@ function formatChunkInnerText(chunk) {
 }
 
 function dosubmit(disallow_abort) {
+    ignore_stream = false;
     submit_start = Date.now();
     var txt = input_text.val().replace(/\u00a0/g, " ");
     if((disallow_abort || gamestate !== "wait") && !memorymode && !gamestarted && ((!adventure || !action_mode) && txt.trim().length == 0)) {
@@ -902,6 +905,7 @@
 }
 
 function _dosubmit() {
+    ignore_stream = false;
     var txt = submit_throttle.txt;
     var disallow_abort = submit_throttle.disallow_abort;
     submit_throttle = null;
@@ -2082,6 +2086,15 @@ function unbindGametext() {
     gametext_bound = false;
 }
 
+function endStream() {
+    // Clear the stream preview; the real text is about to be displayed.
+    ignore_stream = true;
+    if (stream_preview) {
+        stream_preview.remove();
+        stream_preview = null;
+    }
+}
+
 function update_gpu_layers() {
     var gpu_layers
     gpu_layers = 0;
@@ -2258,6 +2271,21 @@ $(document).ready(function(){
                 active_element.focus();
             })();
             $("body").addClass("connected");
+        } else if (msg.cmd == "streamtoken") {
+            // Sometimes the streamtoken messages will come in too late, after
+            // we have received the full text. This leads to some stray tokens
+            // appearing after the output. To combat this, we only allow tokens
+            // to be displayed after requesting and before receiving text.
+            if (ignore_stream) return;
+            if (!$("#setoutputstreaming")[0].checked) return;
+
+            if (!stream_preview) {
+                stream_preview = document.createElement("span");
+                game_text.append(stream_preview);
+            }
+
+            stream_preview.innerText += msg.data.join("");
+            scrollToBottom();
         } else if(msg.cmd == "updatescreen") {
             var _gamestarted = gamestarted;
             gamestarted = msg.gamestarted;
@@ -2333,6 +2361,7 @@
         } else if(msg.cmd == "setgamestate") {
             // Enable or Disable buttons
             if(msg.data == "ready") {
+                endStream();
                 enableSendBtn();
                 enableButtons([button_actmem, button_actwi, button_actback, button_actfwd, button_actretry]);
                 hideWaitAnimation();
@@ -2519,6 +2548,9 @@
         } else if(msg.cmd == "updatesingleline") {
             // Update toggle state
             $("#singleline").prop('checked', msg.data).change();
+        } else if(msg.cmd == "updateoutputstreaming") {
+            // Update toggle state
+            $("#setoutputstreaming").prop('checked', msg.data).change();
         } else if(msg.cmd == "allowtoggle") {
             // Allow toggle change states to propagate
             allowtoggle = msg.data;
@@ -2914,6 +2946,7 @@
     });
 
     button_actretry.on("click", function(ev) {
+        ignore_stream = false;
        hideMessage();
        socket.send({'cmd': 'retry', 'chatname': chatmode ? chat_name.val() : undefined, 'data': ''});
        hidegenseqs();
@@ -3160,6 +3193,7 @@
     });
 
     rs_accept.on("click", function(ev) {
+        ignore_stream = false;
        hideMessage();
        socket.send({'cmd': 'rndgame', 'memory': $("#rngmemory").val(), 'data': topic.val()});
        hideRandomStoryPopup();