Merge pull request #178 from one-some/token-prob
Add token probability visualizer
This commit is contained in:
commit
8bcf4187ac
70
aiserver.py
70
aiserver.py
|
@ -215,6 +215,19 @@ model_menu = {
|
|||
["Return to Main Menu", "mainmenu", "", True],
|
||||
]
|
||||
}
|
||||
|
||||
class TokenStreamQueue:
    """Holds streamed tokens, each paired with its probability data.

    ``probability_buffer`` is a one-slot staging area: whatever is stored
    there is attached to the next piece of text passed to ``add_text`` and
    the slot is then emptied. ``queue`` is the list of pending entries.
    """

    def __init__(self):
        # Pending entries; each is a dict with "decoded" and "probabilities".
        self.queue = []
        # Probability data waiting for the next add_text() call, or None.
        self.probability_buffer = None

    def add_text(self, text):
        """Append ``text`` to the queue together with the staged probabilities.

        The entry's "probabilities" value is whatever ``probability_buffer``
        held at call time (``None`` if nothing was staged). The buffer is
        reset afterwards so the same data is never attached twice.
        """
        staged, self.probability_buffer = self.probability_buffer, None
        entry = {"decoded": text, "probabilities": staged}
        self.queue.append(entry)
|
||||
|
||||
# Variables
|
||||
class vars:
|
||||
lastact = "" # The last action received from the user
|
||||
|
@ -352,7 +365,8 @@ class vars:
|
|||
use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != "" # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
|
||||
revision = None
|
||||
output_streaming = False
|
||||
token_stream_queue = [] # Queue for the token streaming
|
||||
token_stream_queue = TokenStreamQueue() # Queue for the token streaming
|
||||
show_probs = False # Whether or not to show token probabilities
|
||||
|
||||
utils.vars = vars
|
||||
|
||||
|
@ -803,6 +817,7 @@ def savesettings():
|
|||
js["autosave"] = vars.autosave
|
||||
js["welcome"] = vars.welcome
|
||||
js["output_streaming"] = vars.output_streaming
|
||||
js["show_probs"] = vars.show_probs
|
||||
|
||||
if(vars.seed_specified):
|
||||
js["seed"] = vars.seed
|
||||
|
@ -916,6 +931,8 @@ def processsettings(js):
|
|||
vars.welcome = js["welcome"]
|
||||
if("output_streaming" in js):
|
||||
vars.output_streaming = js["output_streaming"]
|
||||
if("show_probs" in js):
|
||||
vars.show_probs = js["show_probs"]
|
||||
|
||||
if("seed" in js):
|
||||
vars.seed = js["seed"]
|
||||
|
@ -955,11 +972,11 @@ def check_for_sp_change():
|
|||
emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
|
||||
vars.sp_changed = False
|
||||
|
||||
if(vars.output_streaming and vars.token_stream_queue):
|
||||
if(vars.token_stream_queue.queue):
|
||||
# If emit blocks, waiting for it to complete before clearing could
|
||||
# introduce a race condition that drops tokens.
|
||||
queued_tokens = list(vars.token_stream_queue)
|
||||
vars.token_stream_queue.clear()
|
||||
queued_tokens = list(vars.token_stream_queue.queue)
|
||||
vars.token_stream_queue.queue.clear()
|
||||
socketio.emit("from_server", {"cmd": "streamtoken", "data": queued_tokens}, namespace=None, broadcast=True)
|
||||
|
||||
socketio.start_background_task(check_for_sp_change)
|
||||
|
@ -1509,10 +1526,37 @@ def patch_transformers():
|
|||
assert scores.shape == scores_shape
|
||||
|
||||
return scores
|
||||
|
||||
from torch.nn import functional as F
|
||||
|
||||
class ProbabilityVisualizerLogitsProcessor(LogitsProcessor):
    """Record the most probable next tokens without altering the scores.

    When ``vars.show_probs`` is enabled and only one sequence is being
    generated, the softmax of the first row of ``scores`` is taken and the
    ``top_k`` most probable tokens are stored in
    ``vars.token_stream_queue.probability_buffer`` as a list of dicts with
    the keys ``tokenId``, ``decoded`` and ``score``.

    Args:
        top_k: How many candidate tokens to record per step (default 8).
    """

    def __init__(self, top_k: int = 8):
        self.top_k = top_k

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Stage the top candidates for the next token and return ``scores`` unchanged."""
        assert scores.ndim == 2
        assert input_ids.ndim == 2

        # Nothing is recorded for multi-sequence generation or when the
        # feature is switched off.
        if vars.numseqs > 1 or not vars.show_probs:
            return scores

        # Only the first row is inspected. torch.topk returns the k largest
        # entries already sorted in descending order, which avoids sorting
        # the entire vocabulary in Python on every generated token.
        # NOTE: the order among exactly-equal probabilities may differ from
        # a stable full sort.
        probs = F.softmax(scores[0], dim=-1)
        k = min(self.top_k, probs.shape[-1])
        top_scores, top_ids = torch.topk(probs, k)

        vars.token_stream_queue.probability_buffer = [
            {
                "tokenId": token_id,
                "decoded": utils.decodenewlines(tokenizer.decode(token_id)),
                "score": score,
            }
            for token_id, score in zip(top_ids.tolist(), top_scores.tolist())
        ]
        return scores
|
||||
|
||||
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
    """Build the stock logits-processor list and add the project's own.

    All arguments are forwarded untouched to the original implementation
    saved on ``new_get_logits_processor.old_get_logits_processor``. A
    ``LuaLogitsProcessor`` is then placed at the front of the returned list
    and a ``ProbabilityVisualizerLogitsProcessor`` at the end.
    """
    build_default = new_get_logits_processor.old_get_logits_processor
    processor_list = build_default(*args, **kwargs)
    # Lua processor goes first in the list ...
    processor_list.insert(0, LuaLogitsProcessor())
    # ... and the probability visualizer goes last.
    processor_list.append(ProbabilityVisualizerLogitsProcessor())
    return processor_list
|
||||
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
|
||||
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
|
||||
|
@ -1568,12 +1612,14 @@ def patch_transformers():
|
|||
**kwargs,
|
||||
) -> bool:
|
||||
# Do not intermingle multiple generations' outputs!
|
||||
if(vars.numseqs > 1):
|
||||
if vars.numseqs > 1:
|
||||
return False
|
||||
|
||||
if not (vars.show_probs or vars.output_streaming):
|
||||
return False
|
||||
|
||||
tokenizer_text = utils.decodenewlines(tokenizer.decode(input_ids[0, -1]))
|
||||
|
||||
vars.token_stream_queue.append(tokenizer_text)
|
||||
vars.token_stream_queue.add_text(tokenizer_text)
|
||||
return False
|
||||
|
||||
|
||||
|
@ -2760,7 +2806,8 @@ def lua_has_setting(setting):
|
|||
"rmspch",
|
||||
"adsnsp",
|
||||
"singleline",
|
||||
"output_streaming"
|
||||
"output_streaming",
|
||||
"show_probs"
|
||||
)
|
||||
|
||||
#==================================================================#
|
||||
|
@ -2793,6 +2840,7 @@ def lua_get_setting(setting):
|
|||
if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
|
||||
if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]
|
||||
if(setting == "output_streaming"): return vars.output_streaming
|
||||
if(setting == "show_probs"): return vars.show_probs
|
||||
|
||||
#==================================================================#
|
||||
# Set the setting with the given name if it exists
|
||||
|
@ -2830,6 +2878,7 @@ def lua_set_setting(setting, v):
|
|||
if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
|
||||
if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v
|
||||
if(setting == "output_streaming"): vars.output_streaming = v
|
||||
if(setting == "show_probs"): vars.show_probs = v
|
||||
|
||||
#==================================================================#
|
||||
# Get contents of memory
|
||||
|
@ -3546,6 +3595,10 @@ def get_message(msg):
|
|||
vars.output_streaming = msg['data']
|
||||
settingschanged()
|
||||
refresh_settings()
|
||||
elif(msg['cmd'] == 'setshowprobs'):
|
||||
vars.show_probs = msg['data']
|
||||
settingschanged()
|
||||
refresh_settings()
|
||||
elif(not vars.host and msg['cmd'] == 'importwi'):
|
||||
wiimportrequest()
|
||||
elif(msg['cmd'] == 'debug'):
|
||||
|
@ -4744,6 +4797,7 @@ def refresh_settings():
|
|||
emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': vars.formatoptns["frmtadsnsp"]}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatesingleline', 'data': vars.formatoptns["singleline"]}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updateoutputstreaming', 'data': vars.output_streaming}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updateshowprobs', 'data': vars.show_probs}, broadcast=True)
|
||||
|
||||
# Allow toggle events again
|
||||
emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True)
|
||||
|
|
|
@ -263,6 +263,17 @@ gensettingstf = [
|
|||
"default": 0,
|
||||
"tooltip": "Shows outputs to you as they are made. Does not work with more than one gens per action."
|
||||
},
|
||||
{
|
||||
"uitype": "toggle",
|
||||
"unit": "bool",
|
||||
"label": "Show Token Probabilities",
|
||||
"id": "setshowprobs",
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
"step": 1,
|
||||
"default": 0,
|
||||
"tooltip": "Shows token selection probabilities. Does not work with more than one gens per action."
|
||||
},
|
||||
]
|
||||
|
||||
gensettingsik =[{
|
||||
|
|
|
@ -79,6 +79,7 @@ var rs_close;
|
|||
var seqselmenu;
|
||||
var seqselcontents;
|
||||
var stream_preview;
|
||||
var token_prob_container;
|
||||
|
||||
var storyname = null;
|
||||
var memorymode = false;
|
||||
|
@ -890,7 +891,7 @@ function formatChunkInnerText(chunk) {
|
|||
}
|
||||
|
||||
function dosubmit(disallow_abort) {
|
||||
ignore_stream = false;
|
||||
beginStream();
|
||||
submit_start = Date.now();
|
||||
var txt = input_text.val().replace(/\u00a0/g, " ");
|
||||
if((disallow_abort || gamestate !== "wait") && !memorymode && !gamestarted && ((!adventure || !action_mode) && txt.trim().length == 0)) {
|
||||
|
@ -905,7 +906,7 @@ function dosubmit(disallow_abort) {
|
|||
}
|
||||
|
||||
function _dosubmit() {
|
||||
ignore_stream = false;
|
||||
beginStream();
|
||||
var txt = submit_throttle.txt;
|
||||
var disallow_abort = submit_throttle.disallow_abort;
|
||||
submit_throttle = null;
|
||||
|
@ -2088,6 +2089,11 @@ function unbindGametext() {
|
|||
gametext_bound = false;
|
||||
}
|
||||
|
||||
// Prepare for a new generation: accept streamed tokens again and clear
// the token-probability display left over from the previous one.
function beginStream() {
	ignore_stream = false;
	// .empty() is a no-op when the selection matched nothing, whereas
	// indexing [0] would throw if #token_prob_container is missing and
	// abort the submit/retry handler that called us.
	token_prob_container.empty();
}
|
||||
|
||||
function endStream() {
|
||||
// Clear stream, the real text is about to be displayed.
|
||||
ignore_stream = true;
|
||||
|
@ -2125,6 +2131,14 @@ function RemoveAllButFirstOption(selectElement) {
|
|||
}
|
||||
}
|
||||
|
||||
// Linearly blend two [r, g, b] triples. t = 0 yields color0, t = 1 yields
// color1. The result is not rounded or clamped.
function interpolateRGB(color0, color1, t) {
	const blend = (from, to) => from + ((to - from) * t);
	const red = blend(color0[0], color1[0]);
	const green = blend(color0[1], color1[1]);
	const blue = blend(color0[2], color1[2]);
	return [red, green, blue];
}
|
||||
|
||||
//=================================================================//
|
||||
// READY/RUNTIME
|
||||
//=================================================================//
|
||||
|
@ -2216,6 +2230,8 @@ $(document).ready(function(){
|
|||
rs_close = $("#btn_rsclose");
|
||||
seqselmenu = $("#seqselmenu");
|
||||
seqselcontents = $("#seqselcontents");
|
||||
token_prob_container = $("#token_prob_container");
|
||||
token_prob_menu = $("#token_prob_menu");
|
||||
|
||||
// Connect to SocketIO server
|
||||
socket = io.connect(window.document.origin, {transports: ['polling', 'websocket'], closeOnBeforeunload: false});
|
||||
|
@ -2279,14 +2295,68 @@ $(document).ready(function(){
|
|||
// appearing after the output. To combat this, we only allow tokens
|
||||
// to be displayed after requesting and before receiving text.
|
||||
if (ignore_stream) return;
|
||||
if (!$("#setoutputstreaming")[0].checked) return;
|
||||
|
||||
if (!stream_preview) {
|
||||
let streamingEnabled = $("#setoutputstreaming")[0].checked;
|
||||
let probabilitiesEnabled = $("#setshowprobs")[0].checked;
|
||||
|
||||
if (!streamingEnabled && !probabilitiesEnabled) return;
|
||||
|
||||
if (!stream_preview && streamingEnabled) {
|
||||
stream_preview = document.createElement("span");
|
||||
game_text.append(stream_preview);
|
||||
}
|
||||
|
||||
stream_preview.innerText += msg.data.join("");
|
||||
for (const token of msg.data) {
|
||||
if (streamingEnabled) stream_preview.innerText += token.decoded;
|
||||
|
||||
if (probabilitiesEnabled) {
|
||||
// Probability display
|
||||
let probDiv = document.createElement("div");
|
||||
probDiv.classList.add("token-probs");
|
||||
|
||||
let probTokenSpan = document.createElement("span");
|
||||
probTokenSpan.classList.add("token-probs-header");
|
||||
probTokenSpan.innerText = token.decoded.replaceAll("\n", "\\n");
|
||||
probDiv.appendChild(probTokenSpan);
|
||||
|
||||
let probTable = document.createElement("table");
|
||||
let probTBody = document.createElement("tbody");
|
||||
probTable.appendChild(probTBody);
|
||||
|
||||
for (const probToken of token.probabilities) {
|
||||
let tr = document.createElement("tr");
|
||||
let rgb = interpolateRGB(
|
||||
[255, 255, 255],
|
||||
[0, 255, 0],
|
||||
probToken.score
|
||||
).map(Math.round);
|
||||
let color = `rgb(${rgb.join(", ")})`;
|
||||
|
||||
if (probToken.decoded === token.decoded) {
|
||||
tr.classList.add("token-probs-final-token");
|
||||
}
|
||||
|
||||
let tds = {};
|
||||
|
||||
for (const property of ["tokenId", "decoded", "score"]) {
|
||||
let td = document.createElement("td");
|
||||
td.style.color = color;
|
||||
tds[property] = td;
|
||||
tr.appendChild(td);
|
||||
}
|
||||
|
||||
tds.tokenId.innerText = probToken.tokenId;
|
||||
tds.decoded.innerText = probToken.decoded.toString().replaceAll("\n", "\\n");
|
||||
tds.score.innerText = (probToken.score * 100).toFixed(2) + "%";
|
||||
|
||||
probTBody.appendChild(tr);
|
||||
}
|
||||
|
||||
probDiv.appendChild(probTable);
|
||||
token_prob_container.append(probDiv);
|
||||
}
|
||||
}
|
||||
|
||||
scrollToBottom();
|
||||
} else if(msg.cmd == "updatescreen") {
|
||||
var _gamestarted = gamestarted;
|
||||
|
@ -2561,6 +2631,14 @@ $(document).ready(function(){
|
|||
} else if(msg.cmd == "updateoutputstreaming") {
|
||||
// Update toggle state
|
||||
$("#setoutputstreaming").prop('checked', msg.data).change();
|
||||
} else if(msg.cmd == "updateshowprobs") {
|
||||
$("#setshowprobs").prop('checked', msg.data).change();
|
||||
|
||||
if(msg.data) {
|
||||
token_prob_menu.removeClass("hidden");
|
||||
} else {
|
||||
token_prob_menu.addClass("hidden");
|
||||
}
|
||||
} else if(msg.cmd == "allowtoggle") {
|
||||
// Allow toggle change states to propagate
|
||||
allowtoggle = msg.data;
|
||||
|
@ -2956,7 +3034,7 @@ $(document).ready(function(){
|
|||
});
|
||||
|
||||
button_actretry.on("click", function(ev) {
|
||||
ignore_stream = false;
|
||||
beginStream();
|
||||
hideMessage();
|
||||
socket.send({'cmd': 'retry', 'chatname': chatmode ? chat_name.val() : undefined, 'data': ''});
|
||||
hidegenseqs();
|
||||
|
@ -3203,7 +3281,7 @@ $(document).ready(function(){
|
|||
});
|
||||
|
||||
rs_accept.on("click", function(ev) {
|
||||
ignore_stream = false;
|
||||
beginStream();
|
||||
hideMessage();
|
||||
socket.send({'cmd': 'rndgame', 'memory': $("#rngmemory").val(), 'data': topic.val()});
|
||||
hideRandomStoryPopup();
|
||||
|
|
|
@ -1647,4 +1647,51 @@ body.connected .popupfooter, .popupfooter.always-available {
|
|||
.breadcrumbitem:hover {
|
||||
cursor: pointer;
|
||||
background-color: #688f1f;
|
||||
}
|
||||
}
|
||||
|
||||
#token_prob_menu {
|
||||
color: white;
|
||||
background-color: #262626;
|
||||
}
|
||||
|
||||
.token-probs {
|
||||
display: inline-block;
|
||||
text-align: center;
|
||||
margin-right: 5px;
|
||||
}
|
||||
|
||||
.token-probs > table {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.token-probs > table > tbody > tr > td {
|
||||
border: 1px solid #262626;
|
||||
border-collapse: collapse;
|
||||
padding: 2px 15px;
|
||||
}
|
||||
|
||||
.token-probs > table > tbody > tr {
|
||||
background-color: #3e3e3e;
|
||||
}
|
||||
|
||||
.token-probs > table > tbody > tr:nth-child(2n) {
|
||||
background-color: #575757;
|
||||
}
|
||||
|
||||
.token-probs-final-token {
|
||||
font-weight: bold;
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.token-probs-final-token > td {
|
||||
background: #5c8a5a;
|
||||
}
|
||||
|
||||
.token-probs-header {
|
||||
display: block;
|
||||
}
|
||||
|
||||
#token_prob_container {
|
||||
overflow-x: scroll;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
|
|
@ -123,6 +123,11 @@
|
|||
</div>
|
||||
<div class="row" id="formatmenu">
|
||||
</div>
|
||||
|
||||
<div id="token_prob_menu" class="row hidden">
|
||||
<div id="token_prob_container"></div>
|
||||
</div>
|
||||
|
||||
<div class="layer-container">
|
||||
<div class="layer-bottom row" id="gamescreen">
|
||||
<span id="gametext" contenteditable="true"><p>...</p></span>
|
||||
|
|
Loading…
Reference in New Issue