diff --git a/aiserver.py b/aiserver.py
index 216b6b69..c2a273b0 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -215,6 +215,19 @@ model_menu = {
         ["Return to Main Menu", "mainmenu", "", True],
     ]
 }
+
+class TokenStreamQueue:
+    # Pairs each streamed token with the top-probability data captured for
+    # it: the logits processor fills probability_buffer, and add_text()
+    # attaches that buffer to the decoded token before resetting it.
+    def __init__(self):
+        self.probability_buffer = None
+        self.queue = []
+
+    def add_text(self, text):
+        self.queue.append({
+            "decoded": text,
+            "probabilities": self.probability_buffer
+        })
+        self.probability_buffer = None
+
 # Variables
 class vars:
     lastact = ""    # The last action received from the user
@@ -352,7 +365,8 @@ class vars:
     use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != ""  # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
     revision = None
     output_streaming = False
-    token_stream_queue = []  # Queue for the token streaming
+    token_stream_queue = TokenStreamQueue()  # Queue for the token streaming
+    show_probs = False  # Whether or not to show token probabilities

 utils.vars = vars

@@ -803,6 +817,7 @@ def savesettings():
     js["autosave"] = vars.autosave
     js["welcome"] = vars.welcome
     js["output_streaming"] = vars.output_streaming
+    js["show_probs"] = vars.show_probs

     if(vars.seed_specified):
         js["seed"] = vars.seed
@@ -916,6 +931,8 @@ def processsettings(js):
         vars.welcome = js["welcome"]
     if("output_streaming" in js):
         vars.output_streaming = js["output_streaming"]
+    if("show_probs" in js):
+        vars.show_probs = js["show_probs"]
     if("seed" in js):
         vars.seed = js["seed"]

@@ -955,11 +972,11 @@ def check_for_sp_change():
         emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
         vars.sp_changed = False

-    if(vars.output_streaming and vars.token_stream_queue):
+    if(vars.token_stream_queue.queue):
+        # If emit blocks, waiting for it to complete before clearing could
+        # introduce a race condition that drops tokens.
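+        # Each queued entry is a dict of the shape built by
+        # TokenStreamQueue.add_text(): {"decoded": str, "probabilities": list or None}.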
-        queued_tokens = list(vars.token_stream_queue)
-        vars.token_stream_queue.clear()
+        queued_tokens = list(vars.token_stream_queue.queue)
+        vars.token_stream_queue.queue.clear()
         socketio.emit("from_server", {"cmd": "streamtoken", "data": queued_tokens}, namespace=None, broadcast=True)

     socketio.start_background_task(check_for_sp_change)
@@ -1509,10 +1526,37 @@ def patch_transformers():
             assert scores.shape == scores_shape

             return scores
+
+    from torch.nn import functional as F
+
+    class ProbabilityVisualizerLogitsProcessor(LogitsProcessor):
+        def __init__(self):
+            pass
+
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+            assert scores.ndim == 2
+            assert input_ids.ndim == 2
+
+            if vars.numseqs > 1 or not vars.show_probs:
+                return scores
+
+            probs = F.softmax(scores, dim=-1).cpu().numpy()[0]
+
+            token_prob_info = []
+            for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
+                token_prob_info.append({
+                    "tokenId": token_id,
+                    "decoded": utils.decodenewlines(tokenizer.decode(token_id)),
+                    "score": float(score),
+                })
+
+            # Stash the top-8 candidates for this step; TokenStreamQueue.add_text()
+            # attaches them to whichever token the sampler actually picks.
+            vars.token_stream_queue.probability_buffer = token_prob_info
+            return scores

     def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
         processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
         processors.insert(0, LuaLogitsProcessor())
+        processors.append(ProbabilityVisualizerLogitsProcessor())
         return processors
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
     transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
@@ -1568,12 +1612,14 @@ def patch_transformers():
             **kwargs,
         ) -> bool:
             # Do not intermingle multiple generations' outputs!
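+            # Stream only when a single sequence is being generated and at
+            # least one consumer (output streaming or the probability
+            # viewer) is enabled.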
-            if(vars.numseqs > 1):
+            if vars.numseqs > 1:
+                return False
+
+            if not (vars.show_probs or vars.output_streaming):
                 return False

             tokenizer_text = utils.decodenewlines(tokenizer.decode(input_ids[0, -1]))
-
-            vars.token_stream_queue.append(tokenizer_text)
+            vars.token_stream_queue.add_text(tokenizer_text)
             return False

@@ -2760,7 +2806,8 @@ def lua_has_setting(setting):
         "rmspch",
         "adsnsp",
         "singleline",
-        "output_streaming"
+        "output_streaming",
+        "show_probs"
     )

 #==================================================================#
@@ -2793,6 +2840,7 @@ def lua_get_setting(setting):
     if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
     if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]
     if(setting == "output_streaming"): return vars.output_streaming
+    if(setting == "show_probs"): return vars.show_probs

 #==================================================================#
 #  Set the setting with the given name if it exists
@@ -2830,6 +2878,7 @@ def lua_set_setting(setting, v):
     if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
     if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v
     if(setting == "output_streaming"): vars.output_streaming = v
+    if(setting == "show_probs"): vars.show_probs = v

 #==================================================================#
 #  Get contents of memory
@@ -3546,6 +3595,10 @@ def get_message(msg):
         vars.output_streaming = msg['data']
         settingschanged()
         refresh_settings()
+    elif(msg['cmd'] == 'setshowprobs'):
+        vars.show_probs = msg['data']
+        settingschanged()
+        refresh_settings()
     elif(not vars.host and msg['cmd'] == 'importwi'):
         wiimportrequest()
     elif(msg['cmd'] == 'debug'):
@@ -4744,6 +4797,7 @@ def refresh_settings():
     emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': vars.formatoptns["frmtadsnsp"]}, broadcast=True)
     emit('from_server', {'cmd': 'updatesingleline', 'data': vars.formatoptns["singleline"]}, broadcast=True)
     emit('from_server', {'cmd': 'updateoutputstreaming', 'data': vars.output_streaming}, broadcast=True)
+    emit('from_server', {'cmd': 'updateshowprobs', 'data': vars.show_probs}, broadcast=True)

     # Allow toggle events again
     emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True)
diff --git a/gensettings.py b/gensettings.py
index 5ec022ef..bd644fa8 100644
--- a/gensettings.py
+++ b/gensettings.py
@@ -263,6 +263,17 @@ gensettingstf = [
     "default": 0,
     "tooltip": "Shows outputs to you as they are made. Does not work with more than one gen per action."
     },
+    {
+    "uitype": "toggle",
+    "unit": "bool",
+    "label": "Show Token Probabilities",
+    "id": "setshowprobs",
+    "min": 0,
+    "max": 1,
+    "step": 1,
+    "default": 0,
+    "tooltip": "Shows token selection probabilities. Does not work with more than one gen per action."
+    },
 ]

 gensettingsik =[{
diff --git a/static/application.js b/static/application.js
index 952a5aa7..12b8f214 100644
--- a/static/application.js
+++ b/static/application.js
@@ -79,6 +79,7 @@ var rs_close;
 var seqselmenu;
 var seqselcontents;
 var stream_preview;
+var token_prob_container;
+var token_prob_menu;

 var storyname = null;
 var memorymode = false;
@@ -890,7 +891,7 @@ function formatChunkInnerText(chunk) {
 }

 function dosubmit(disallow_abort) {
-    ignore_stream = false;
+    beginStream();
     submit_start = Date.now();
     var txt = input_text.val().replace(/\u00a0/g, " ");
     if((disallow_abort || gamestate !== "wait") && !memorymode && !gamestarted && ((!adventure || !action_mode) && txt.trim().length == 0)) {
@@ -905,7 +906,7 @@
 }

 function _dosubmit() {
-    ignore_stream = false;
+    beginStream();
     var txt = submit_throttle.txt;
     var disallow_abort = submit_throttle.disallow_abort;
     submit_throttle = null;
@@ -2088,6 +2089,11 @@ function unbindGametext() {
     gametext_bound = false;
 }

+function beginStream() {
+    ignore_stream = false;
+    token_prob_container[0].innerHTML = "";
+}
+
 function endStream() {
     // Clear stream, the real text is about to be displayed.
     ignore_stream = true;
@@ -2125,6 +2131,14 @@ function RemoveAllButFirstOption(selectElement) {
     }
 }

+// Linearly blend two [r, g, b] colors; t = 0 gives color0, t = 1 gives color1.
+function interpolateRGB(color0, color1, t) {
+    return [
+        color0[0] + ((color1[0] - color0[0]) * t),
+        color0[1] + ((color1[1] - color0[1]) * t),
+        color0[2] + ((color1[2] - color0[2]) * t),
+    ];
+}
+
 //=================================================================//
 //  READY/RUNTIME
 //=================================================================//
@@ -2216,6 +2230,8 @@ $(document).ready(function(){
     rs_close = $("#btn_rsclose");
     seqselmenu = $("#seqselmenu");
     seqselcontents = $("#seqselcontents");
+    token_prob_container = $("#token_prob_container");
+    token_prob_menu = $("#token_prob_menu");

     // Connect to SocketIO server
     socket = io.connect(window.document.origin, {transports: ['polling', 'websocket'], closeOnBeforeunload: false});
@@ -2279,14 +2295,68 @@
             // appearing after the output. To combat this, we only allow tokens
             // to be displayed after requesting and before receiving text.
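+            // msg.data is an array of {decoded, probabilities} entries flushed
+            // from the server's TokenStreamQueue since the last poll.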
             if (ignore_stream) return;
-            if (!$("#setoutputstreaming")[0].checked) return;
-            if (!stream_preview) {
+            let streamingEnabled = $("#setoutputstreaming")[0].checked;
+            let probabilitiesEnabled = $("#setshowprobs")[0].checked;
+
+            if (!streamingEnabled && !probabilitiesEnabled) return;
+
+            if (!stream_preview && streamingEnabled) {
                 stream_preview = document.createElement("span");
                 game_text.append(stream_preview);
             }
-            stream_preview.innerText += msg.data.join("");
+
+            for (const token of msg.data) {
+                if (streamingEnabled) stream_preview.innerText += token.decoded;
+
+                if (probabilitiesEnabled) {
+                    // Probability display
+                    let probDiv = document.createElement("div");
+                    probDiv.classList.add("token-probs");
+
+                    let probTokenSpan = document.createElement("span");
+                    probTokenSpan.classList.add("token-probs-header");
+                    probTokenSpan.innerText = token.decoded.replaceAll("\n", "\\n");
+                    probDiv.appendChild(probTokenSpan);
+
+                    let probTable = document.createElement("table");
+                    let probTBody = document.createElement("tbody");
+                    probTable.appendChild(probTBody);
+
+                    for (const probToken of token.probabilities) {
+                        let tr = document.createElement("tr");
+                        // Shade from white toward green as the probability approaches 1.
+                        let rgb = interpolateRGB(
+                            [255, 255, 255],
+                            [0, 255, 0],
+                            probToken.score
+                        ).map(Math.round);
+                        let color = `rgb(${rgb.join(", ")})`;
+
+                        if (probToken.decoded === token.decoded) {
+                            tr.classList.add("token-probs-final-token");
+                        }
+
+                        let tds = {};
+
+                        for (const property of ["tokenId", "decoded", "score"]) {
+                            let td = document.createElement("td");
+                            td.style.color = color;
+                            tds[property] = td;
+                            tr.appendChild(td);
+                        }
+
+                        tds.tokenId.innerText = probToken.tokenId;
+                        tds.decoded.innerText = probToken.decoded.replaceAll("\n", "\\n");
+                        tds.score.innerText = (probToken.score * 100).toFixed(2) + "%";
+
+                        probTBody.appendChild(tr);
+                    }
+
+                    probDiv.appendChild(probTable);
+                    token_prob_container.append(probDiv);
+                }
+            }
+
+            scrollToBottom();
         } else if(msg.cmd == "updatescreen") {
             var _gamestarted = gamestarted;
@@ -2561,6 +2631,14 @@
         } else if(msg.cmd == "updateoutputstreaming") {
             // Update toggle state
             $("#setoutputstreaming").prop('checked', msg.data).change();
+        } else if(msg.cmd == "updateshowprobs") {
+            $("#setshowprobs").prop('checked', msg.data).change();
+
+            if(msg.data) {
+                token_prob_menu.removeClass("hidden");
+            } else {
+                token_prob_menu.addClass("hidden");
+            }
         } else if(msg.cmd == "allowtoggle") {
             // Allow toggle change states to propagate
             allowtoggle = msg.data;
@@ -2956,7 +3034,7 @@
     });

     button_actretry.on("click", function(ev) {
-        ignore_stream = false;
+        beginStream();
         hideMessage();
         socket.send({'cmd': 'retry', 'chatname': chatmode ? chat_name.val() : undefined, 'data': ''});
         hidegenseqs();
     });
@@ -3203,7 +3281,7 @@
     });

     rs_accept.on("click", function(ev) {
-        ignore_stream = false;
+        beginStream();
         hideMessage();
         socket.send({'cmd': 'rndgame', 'memory': $("#rngmemory").val(), 'data': topic.val()});
         hideRandomStoryPopup();
diff --git a/static/custom.css b/static/custom.css
index 10ed2dc6..fec9c506 100644
--- a/static/custom.css
+++ b/static/custom.css
@@ -1647,4 +1647,51 @@ body.connected .popupfooter, .popupfooter.always-available {
 .breadcrumbitem:hover {
     cursor: pointer;
     background-color: #688f1f;
-}
\ No newline at end of file
+}
+
+#token_prob_menu {
+    color: white;
+    background-color: #262626;
+}
+
+.token-probs {
+    display: inline-block;
+    text-align: center;
+    margin-right: 5px;
+}
+
+.token-probs > table {
+    width: 100%;
+}
+
+.token-probs > table > tbody > tr > td {
+    border: 1px solid #262626;
+    border-collapse: collapse;
+    padding: 2px 15px;
+}
+
+.token-probs > table > tbody > tr {
+    background-color: #3e3e3e;
+}
+
+.token-probs > table > tbody > tr:nth-child(2n) {
+    background-color: #575757;
+}
+
+.token-probs-final-token {
+    font-weight: bold;
+    text-decoration: underline;
+}
+
+.token-probs-final-token > td {
+    background: #5c8a5a;
+}
+
+.token-probs-header {
+    display: block;
+}
+
+#token_prob_container {
+    overflow-x: scroll;
+    white-space: nowrap;
+}
diff --git a/templates/index.html b/templates/index.html
index 9ea89205..11833ad0 100644
--- a/templates/index.html
+++ b/templates/index.html
@@ -123,6 +123,11 @@
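+        <!-- Populated by the "streamtoken" handler in application.js; unhidden
+             while the "Show Token Probabilities" setting is enabled. -->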
+        <div id="token_prob_menu" class="hidden">
+            <div id="token_prob_container"></div>
+        </div>
...