Merge pull request #178 from one-some/token-prob

Add token probability visualizer
henk717 2022-08-05 14:27:46 +02:00 committed by GitHub
commit 8bcf4187ac
5 changed files with 211 additions and 16 deletions

aiserver.py

@@ -215,6 +215,19 @@ model_menu = {
        ["Return to Main Menu", "mainmenu", "", True],
    ]
}

+class TokenStreamQueue:
+    def __init__(self):
+        self.probability_buffer = None
+        self.queue = []
+
+    def add_text(self, text):
+        self.queue.append({
+            "decoded": text,
+            "probabilities": self.probability_buffer
+        })
+        self.probability_buffer = None
+
# Variables
class vars:
    lastact = ""  # The last action received from the user
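The plain list used for streaming becomes a small class because the two pieces of information for one step arrive at different times: a logits processor records the candidate probabilities before sampling, and only afterwards does the streaming callback learn which token was actually chosen. `probability_buffer` bridges that gap. A minimal sketch of the handshake, reusing the `TokenStreamQueue` definition above with made-up values:

```python
queue = TokenStreamQueue()

# 1. The logits processor fires first and stashes the candidate
#    probabilities for the upcoming token (values invented here).
queue.probability_buffer = [
    {"tokenId": 464, "decoded": "The", "score": 0.41},
    {"tokenId": 314, "decoded": "I", "score": 0.17},
]

# 2. The per-token callback then reports the sampled token; add_text
#    pairs it with the buffered probabilities and resets the buffer.
queue.add_text("The")

assert queue.queue[-1]["probabilities"][0]["tokenId"] == 464
assert queue.probability_buffer is None
```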
@@ -352,7 +365,8 @@ class vars:
    use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != ""  # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
    revision = None
    output_streaming = False
-    token_stream_queue = []  # Queue for the token streaming
+    token_stream_queue = TokenStreamQueue()  # Queue for the token streaming
+    show_probs = False  # Whether or not to show token probabilities
utils.vars = vars
@@ -803,6 +817,7 @@ def savesettings():
    js["autosave"] = vars.autosave
    js["welcome"] = vars.welcome
    js["output_streaming"] = vars.output_streaming
+    js["show_probs"] = vars.show_probs

    if(vars.seed_specified):
        js["seed"] = vars.seed
@@ -916,6 +931,8 @@ def processsettings(js):
        vars.welcome = js["welcome"]
    if("output_streaming" in js):
        vars.output_streaming = js["output_streaming"]
+    if("show_probs" in js):
+        vars.show_probs = js["show_probs"]
    if("seed" in js):
        vars.seed = js["seed"]
@@ -955,11 +972,11 @@ def check_for_sp_change():
        emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
        vars.sp_changed = False

-    if(vars.output_streaming and vars.token_stream_queue):
+    if(vars.token_stream_queue.queue):
        # If emit blocks, waiting for it to complete before clearing could
        # introduce a race condition that drops tokens.
-        queued_tokens = list(vars.token_stream_queue)
-        vars.token_stream_queue.clear()
+        queued_tokens = list(vars.token_stream_queue.queue)
+        vars.token_stream_queue.queue.clear()
        socketio.emit("from_server", {"cmd": "streamtoken", "data": queued_tokens}, namespace=None, broadcast=True)

    socketio.start_background_task(check_for_sp_change)
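The comment documents an ordering choice worth spelling out: the queue is snapshotted and cleared before `socketio.emit` is called, so a token appended by the generator while the emit blocks lands in the freshly cleared list and goes out on the next pass instead of being lost. A sketch of the hazard this avoids, hypothetical and assuming aiserver.py's `vars` and `socketio` are in scope:

```python
# Unsafe ordering (NOT what the patch does): emit first, clear after.
# socketio.emit(...)  # may block/yield; generator appends token X here
# vars.token_stream_queue.queue.clear()  # token X dropped, never sent

# Safe ordering, as in the patch: snapshot, clear, then emit.
queued_tokens = list(vars.token_stream_queue.queue)  # copy current contents
vars.token_stream_queue.queue.clear()  # later tokens land in the fresh list
socketio.emit("from_server", {"cmd": "streamtoken", "data": queued_tokens},
              namespace=None, broadcast=True)
```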
@@ -1509,10 +1526,37 @@ def patch_transformers():
            assert scores.shape == scores_shape

            return scores

+    from torch.nn import functional as F
+
+    class ProbabilityVisualizerLogitsProcessor(LogitsProcessor):
+        def __init__(self):
+            pass
+
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+            assert scores.ndim == 2
+            assert input_ids.ndim == 2
+
+            if vars.numseqs > 1 or not vars.show_probs:
+                return scores
+
+            probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
+
+            token_prob_info = []
+            for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
+                token_prob_info.append({
+                    "tokenId": token_id,
+                    "decoded": utils.decodenewlines(tokenizer.decode(token_id)),
+                    "score": float(score),
+                })
+
+            vars.token_stream_queue.probability_buffer = token_prob_info
+            return scores
+
    def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
        processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
        processors.insert(0, LuaLogitsProcessor())
+        processors.append(ProbabilityVisualizerLogitsProcessor())
        return processors
    new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
    transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
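`ProbabilityVisualizerLogitsProcessor` returns `scores` unchanged; its only job is the side effect of recording the eight most probable next tokens. Because it is appended to the processor list (after `LuaLogitsProcessor` and anything else already registered), the percentages reflect the scores as adjusted by the earlier processors. A self-contained sketch of the same extraction on a dummy vocabulary:

```python
import torch
from torch.nn import functional as F

# Dummy logits for one sequence over a 10-token vocabulary,
# shaped (batch, vocab_size) like the `scores` argument above.
scores = torch.randn(1, 10)

# Same math as the processor: softmax over the vocabulary,
# then keep the 8 most likely token ids.
probs = F.softmax(scores, dim=-1).cpu().numpy()[0]
top8 = sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]

for token_id, score in top8:
    print(f"token {token_id}: {score * 100:.2f}%")
```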
@@ -1568,12 +1612,14 @@ def patch_transformers():
            **kwargs,
        ) -> bool:
            # Do not intermingle multiple generations' outputs!
-            if(vars.numseqs > 1):
+            if vars.numseqs > 1:
                return False
+
+            if not (vars.show_probs or vars.output_streaming):
+                return False

            tokenizer_text = utils.decodenewlines(tokenizer.decode(input_ids[0, -1]))
-            vars.token_stream_queue.append(tokenizer_text)
+            vars.token_stream_queue.add_text(tokenizer_text)
            return False
@@ -2760,7 +2806,8 @@ def lua_has_setting(setting):
        "rmspch",
        "adsnsp",
        "singleline",
-        "output_streaming"
+        "output_streaming",
+        "show_probs"
    )
#==================================================================#
@@ -2793,6 +2840,7 @@ def lua_get_setting(setting):
    if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
    if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]
    if(setting == "output_streaming"): return vars.output_streaming
+    if(setting == "show_probs"): return vars.show_probs
#==================================================================#
# Set the setting with the given name if it exists
@@ -2830,6 +2878,7 @@ def lua_set_setting(setting, v):
    if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
    if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v
    if(setting == "output_streaming"): vars.output_streaming = v
+    if(setting == "show_probs"): vars.show_probs = v
#==================================================================#
# Get contents of memory
@@ -3546,6 +3595,10 @@ def get_message(msg):
        vars.output_streaming = msg['data']
        settingschanged()
        refresh_settings()
+    elif(msg['cmd'] == 'setshowprobs'):
+        vars.show_probs = msg['data']
+        settingschanged()
+        refresh_settings()
    elif(not vars.host and msg['cmd'] == 'importwi'):
        wiimportrequest()
    elif(msg['cmd'] == 'debug'):
@@ -4744,6 +4797,7 @@ def refresh_settings():
    emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': vars.formatoptns["frmtadsnsp"]}, broadcast=True)
    emit('from_server', {'cmd': 'updatesingleline', 'data': vars.formatoptns["singleline"]}, broadcast=True)
    emit('from_server', {'cmd': 'updateoutputstreaming', 'data': vars.output_streaming}, broadcast=True)
+    emit('from_server', {'cmd': 'updateshowprobs', 'data': vars.show_probs}, broadcast=True)

    # Allow toggle events again
    emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True)

gensettings.py

@@ -263,6 +263,17 @@ gensettingstf = [
    "default": 0,
    "tooltip": "Shows outputs to you as they are made. Does not work with more than one gens per action."
    },
+    {
+    "uitype": "toggle",
+    "unit": "bool",
+    "label": "Show Token Probabilities",
+    "id": "setshowprobs",
+    "min": 0,
+    "max": 1,
+    "step": 1,
+    "default": 0,
+    "tooltip": "Shows token selection probabilities. Does not work with more than one gens per action."
+    },
    ]
gensettingsik =[{
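Note how the new toggle's `id`, `setshowprobs`, is reused across the stack: it names the socket command the server handles in `get_message` (aiserver.py hunk above) and the checkbox element the client reads via `$("#setshowprobs")` (application.js below), so adding this one entry wires up both sides.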

static/application.js

@@ -79,6 +79,7 @@ var rs_close;
var seqselmenu;
var seqselcontents;
var stream_preview;
+var token_prob_container;

var storyname = null;
var memorymode = false;
@@ -890,7 +891,7 @@ function formatChunkInnerText(chunk) {
}

function dosubmit(disallow_abort) {
-    ignore_stream = false;
+    beginStream();
    submit_start = Date.now();
    var txt = input_text.val().replace(/\u00a0/g, " ");
    if((disallow_abort || gamestate !== "wait") && !memorymode && !gamestarted && ((!adventure || !action_mode) && txt.trim().length == 0)) {
@@ -905,7 +906,7 @@ function dosubmit(disallow_abort) {
}

function _dosubmit() {
-    ignore_stream = false;
+    beginStream();
    var txt = submit_throttle.txt;
    var disallow_abort = submit_throttle.disallow_abort;
    submit_throttle = null;
@@ -2088,6 +2089,11 @@ function unbindGametext() {
    gametext_bound = false;
}

+function beginStream() {
+    ignore_stream = false;
+    token_prob_container[0].innerHTML = "";
+}
+
function endStream() {
    // Clear stream, the real text is about to be displayed.
    ignore_stream = true;
@@ -2125,6 +2131,14 @@ function RemoveAllButFirstOption(selectElement) {
    }
}

+function interpolateRGB(color0, color1, t) {
+    return [
+        color0[0] + ((color1[0] - color0[0]) * t),
+        color0[1] + ((color1[1] - color0[1]) * t),
+        color0[2] + ((color1[2] - color0[2]) * t),
+    ]
+}
+
//=================================================================//
//  READY/RUNTIME
//=================================================================//
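`interpolateRGB` is per-channel linear interpolation; the rendering code further down passes a token's probability as `t`, shading each table row from white (near 0%) toward pure green (near 100%). The same computation in Python, for illustration only:

```python
def interpolate_rgb(color0, color1, t):
    """Linear blend between two RGB triples; t in [0, 1]."""
    return [c0 + (c1 - c0) * t for c0, c1 in zip(color0, color1)]

# A 25%-probability token is shaded a quarter of the way from white
# to green, matching the CSS color built in the rendering loop below.
rgb = [round(c) for c in interpolate_rgb([255, 255, 255], [0, 255, 0], 0.25)]
print(f"rgb({rgb[0]}, {rgb[1]}, {rgb[2]})")  # rgb(191, 255, 191)
```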
@@ -2216,6 +2230,8 @@ $(document).ready(function(){
    rs_close            = $("#btn_rsclose");
    seqselmenu          = $("#seqselmenu");
    seqselcontents      = $("#seqselcontents");
+    token_prob_container = $("#token_prob_container");
+    token_prob_menu = $("#token_prob_menu");

    // Connect to SocketIO server
    socket = io.connect(window.document.origin, {transports: ['polling', 'websocket'], closeOnBeforeunload: false});
@@ -2279,14 +2295,68 @@ $(document).ready(function(){
            // appearing after the output. To combat this, we only allow tokens
            // to be displayed after requesting and before receiving text.
            if (ignore_stream) return;
-            if (!$("#setoutputstreaming")[0].checked) return;
-            if (!stream_preview) {
+
+            let streamingEnabled = $("#setoutputstreaming")[0].checked;
+            let probabilitiesEnabled = $("#setshowprobs")[0].checked;
+            if (!streamingEnabled && !probabilitiesEnabled) return;
+
+            if (!stream_preview && streamingEnabled) {
                stream_preview = document.createElement("span");
                game_text.append(stream_preview);
            }
-            stream_preview.innerText += msg.data.join("");
+
+            for (const token of msg.data) {
+                if (streamingEnabled) stream_preview.innerText += token.decoded;
+                if (probabilitiesEnabled) {
+                    // Probability display
+                    let probDiv = document.createElement("div");
+                    probDiv.classList.add("token-probs");
+
+                    let probTokenSpan = document.createElement("span");
+                    probTokenSpan.classList.add("token-probs-header");
+                    probTokenSpan.innerText = token.decoded.replaceAll("\n", "\\n");
+                    probDiv.appendChild(probTokenSpan);
+
+                    let probTable = document.createElement("table");
+                    let probTBody = document.createElement("tbody");
+                    probTable.appendChild(probTBody);
+
+                    for (const probToken of token.probabilities) {
+                        let tr = document.createElement("tr");
+                        let rgb = interpolateRGB(
+                            [255, 255, 255],
+                            [0, 255, 0],
+                            probToken.score
+                        ).map(Math.round);
+                        let color = `rgb(${rgb.join(", ")})`;
+
+                        if (probToken.decoded === token.decoded) {
+                            tr.classList.add("token-probs-final-token");
+                        }
+
+                        let tds = {};
+                        for (const property of ["tokenId", "decoded", "score"]) {
+                            let td = document.createElement("td");
+                            td.style.color = color;
+                            tds[property] = td;
+                            tr.appendChild(td);
+                        }
+
+                        tds.tokenId.innerText = probToken.tokenId;
+                        tds.decoded.innerText = probToken.decoded.toString().replaceAll("\n", "\\n");
+                        tds.score.innerText = (probToken.score * 100).toFixed(2) + "%";
+                        probTBody.appendChild(tr);
+                    }
+
+                    probDiv.appendChild(probTable);
+                    token_prob_container.append(probDiv);
+                }
+            }
+
            scrollToBottom();
        } else if(msg.cmd == "updatescreen") {
            var _gamestarted = gamestarted;
@@ -2561,6 +2631,14 @@ $(document).ready(function(){
        } else if(msg.cmd == "updateoutputstreaming") {
            // Update toggle state
            $("#setoutputstreaming").prop('checked', msg.data).change();
+        } else if(msg.cmd == "updateshowprobs") {
+            $("#setshowprobs").prop('checked', msg.data).change();
+
+            if(msg.data) {
+                token_prob_menu.removeClass("hidden");
+            } else {
+                token_prob_menu.addClass("hidden");
+            }
        } else if(msg.cmd == "allowtoggle") {
            // Allow toggle change states to propagate
            allowtoggle = msg.data;
@@ -2956,7 +3034,7 @@ $(document).ready(function(){
    });

    button_actretry.on("click", function(ev) {
-        ignore_stream = false;
+        beginStream();
        hideMessage();
        socket.send({'cmd': 'retry', 'chatname': chatmode ? chat_name.val() : undefined, 'data': ''});
        hidegenseqs();
@@ -3203,7 +3281,7 @@ $(document).ready(function(){
    });

    rs_accept.on("click", function(ev) {
-        ignore_stream = false;
+        beginStream();
        hideMessage();
        socket.send({'cmd': 'rndgame', 'memory': $("#rngmemory").val(), 'data': topic.val()});
        hideRandomStoryPopup();

static/custom.css

@@ -1647,4 +1647,51 @@ body.connected .popupfooter, .popupfooter.always-available {
.breadcrumbitem:hover {
    cursor: pointer;
    background-color: #688f1f;
-}
+}
+
+#token_prob_menu {
+    color: white;
+    background-color: #262626;
+}
+
+.token-probs {
+    display: inline-block;
+    text-align: center;
+    margin-right: 5px;
+}
+
+.token-probs > table {
+    width: 100%;
+}
+
+.token-probs > table > tbody > tr > td {
+    border: 1px solid #262626;
+    border-collapse: collapse;
+    padding: 2px 15px;
+}
+
+.token-probs > table > tbody > tr {
+    background-color: #3e3e3e;
+}
+
+.token-probs > table > tbody > tr:nth-child(2n) {
+    background-color: #575757;
+}
+
+.token-probs-final-token {
+    font-weight: bold;
+    text-decoration: underline;
+}
+
+.token-probs-final-token > td {
+    background: #5c8a5a;
+}
+
+.token-probs-header {
+    display: block;
+}
+
+#token_prob_container {
+    overflow-x: scroll;
+    white-space: nowrap;
+}

templates/index.html

@@ -123,6 +123,11 @@
            </div>
            <div class="row" id="formatmenu">
            </div>
+            <div id="token_prob_menu" class="row hidden">
+                <div id="token_prob_container"></div>
+            </div>
+
            <div class="layer-container">
                <div class="layer-bottom row" id="gamescreen">
                    <span id="gametext" contenteditable="true"><p>...</p></span>