553 lines
19 KiB
JavaScript
553 lines
19 KiB
JavaScript
import {
|
|
animation_duration,
|
|
chat,
|
|
cleanUpMessage,
|
|
event_types,
|
|
eventSource,
|
|
Generate,
|
|
getGeneratingApi,
|
|
is_send_press,
|
|
isStreamingEnabled,
|
|
} from '../script.js';
|
|
import { debounce, delay, getStringHash } from './utils.js';
|
|
import { decodeTextTokens, getTokenizerBestMatch } from './tokenizers.js';
|
|
import { power_user } from './power-user.js';
|
|
import { callGenericPopup, POPUP_TYPE } from './popup.js';
|
|
import { t } from './i18n.js';
|
|
|
|
/** How many `logprobs_tint_<n>` classes are cycled through when colouring output tokens. */
const TINTS = 4;

/** Maximum number of per-message logprob entries kept in `state.messageLogprobs`; older ones are evicted. */
const MAX_MESSAGE_LOGPROBS = 100;

/** jQuery handle for the `#logprobsReroll` button, which rerolls from the continued prefix. */
const REROLL_BUTTON = $('#logprobsReroll');
|
|
|
|
/**
 * A candidate token paired with the natural logarithm of its probability of
 * being chosen by the model.
 * @typedef {[string, number]} Candidate - (token, logprob)
 */

/**
 * Mixed list of raw DOM nodes and jQuery wrappers, ready to be appended to a
 * parent element in order.
 * @typedef {(Node|JQuery<Text>|JQuery<HTMLElement>)[]} NodeArray - Array of DOM nodes
 */

/**
 * Logprob data captured for one generated message.
 * @typedef {Object} MessageLogprobData
 * @property {number} created - when the data was saved, in ms since the epoch
 * @property {number} hash - hash identifying the source message (see getMessageHash)
 * @property {number} messageId - index of the source message in the chat
 * @property {number} swipeId - swipe of the source message that was active when saved
 * @property {string} api - API that generated the message
 * @property {TokenLogprobs[]} messageLogprobs - per-token logprob data, in the
 *   order the tokens appear in the message
 * @property {string | null} continueFrom - prefix that a 'continue' generation
 *   started from, if any
 */

/**
 * Logprob data for one generated token.
 * @typedef {Object} TokenLogprobs
 * @property {string} token - the token the model actually produced
 * @property {Candidate[]} topLogprobs - the highest-ranked candidates for this position
 */

/**
 * Module-level state backing the Token Probabilities UI.
 * @typedef {Object} LogprobsState
 * @property {?TokenLogprobs} selectedTokenLogprobs - logprobs of the token
 *   currently selected in the output view, or null when nothing is selected
 * @property {Map<number, MessageLogprobData>} messageLogprobs - saved logprob
 *   data for recent messages, keyed by message hash
 */

/** @type {LogprobsState} */
const state = {
    selectedTokenLogprobs: null,
    messageLogprobs: new Map(),
};
|
|
|
|
/**
 * Renders the Token Probabilities UI and all subviews with the active message's
 * logprobs data. When there are no token logprobs to show (or Smooth Streaming
 * is enabled), an explanatory empty-state message is displayed instead.
 *
 * Does nothing while `#logprobs_generation_output` is not visible.
 */
function renderAlternativeTokensView() {
    const view = $('#logprobs_generation_output');
    if (!view.is(':visible')) {
        return;
    }

    view.empty();
    state.selectedTokenLogprobs = null;
    renderTopLogprobs();

    const { messageLogprobs, continueFrom } = getActiveMessageLogprobData() || {};
    const usingSmoothStreaming = isStreamingEnabled() && power_user.smooth_streaming;
    if (!messageLogprobs?.length || usingSmoothStreaming) {
        // The reroll button may still be bound to the previously rendered
        // message's prefix; hide it so it cannot act on stale data.
        REROLL_BUTTON.hide();

        const emptyState = $('<div></div>');
        const noTokensMsg = !power_user.request_token_probabilities
            ? '<span>Enable <b>Request token probabilities</b> in the User Settings menu to use this feature.</span>'
            : usingSmoothStreaming
                ? t`Token probabilities are not available when using Smooth Streaming.`
                : is_send_press
                    ? t`Generation in progress...`
                    : t`No token probabilities available for the current message.`;
        emptyState.html(noTokensMsg);
        emptyState.addClass('logprobs_empty_state');
        view.append(emptyState);
        return;
    }

    const prefix = continueFrom || '';
    const tokenSpans = [];
    REROLL_BUTTON.toggle(!!prefix);

    if (prefix) {
        REROLL_BUTTON.off('click').on('click', () => onPrefixClicked(prefix.length));

        // Split the continued prefix into clickable words while keeping the
        // real delimiters, so each word's character offset into `prefix` is known.
        let cumulativeOffset = 0;
        const words = prefix.split(/\s+/);
        const delimiters = prefix.match(/\s+/g) || [];

        words.forEach((word, i) => {
            const span = $('<span></span>');
            span.text(`${word} `);
            span.addClass('logprobs_output_prefix');
            span.attr('title', t`Reroll from this point`);

            const offset = cumulativeOffset;
            span.on('click', () => onPrefixClicked(offset));
            addKeyboardProps(span);

            tokenSpans.push(span);
            // Delimiters containing a newline render as a line break; others as literal whitespace.
            tokenSpans.push(delimiters[i]?.includes('\n')
                ? document.createElement('br')
                : document.createTextNode(delimiters[i] || ' '),
            );

            cumulativeOffset += word.length + (delimiters[i]?.length || 0);
        });
    }

    messageLogprobs.forEach((tokenData, i) => {
        const { token } = tokenData;
        const span = $('<span></span>');
        span.text(toVisibleWhitespace(token));
        span.addClass('logprobs_output_token');
        span.addClass('logprobs_tint_' + (i % TINTS));
        span.on('click', () => onSelectedTokenChanged(tokenData, span));
        addKeyboardProps(span);
        tokenSpans.push(...withVirtualWhitespace(token, span));
    });

    view.append(tokenSpans);

    // Scroll past long prior context so the first generated token is in view.
    if (prefix) {
        const element = view.find('.logprobs_output_token').first();
        const scrollOffset = element.offset().top - element.parent().offset().top;
        element.parent().scrollTop(scrollOffset);
    }
}
|
|
|
|
/**
 * Makes a clickable element keyboard-accessible: gives it `role="button"`,
 * puts it in the tab order, and activates it on Enter / Space.
 *
 * Native `<button>` elements already fire `click` on Enter / Space, so no
 * extra key handler is attached to them; doing so would activate them twice.
 *
 * @param {JQuery} element - element to enhance
 */
function addKeyboardProps(element) {
    element.attr('role', 'button');
    element.attr('tabindex', '0');

    if (element.is('button')) {
        return;
    }

    element.on('keydown', (e) => {
        if (e.key === 'Enter' || e.key === ' ') {
            // Stop Space from scrolling the surrounding container.
            e.preventDefault();
            element.trigger('click');
        }
    });
}
|
|
|
|
/**
 * renderTopLogprobs renders the top logprobs subview with the currently
 * selected token highlighted. If no token is selected, the subview is left empty.
 *
 * Candidates are listed most-likely first, followed by an `<others>` entry
 * holding the remaining probability mass. Non-positive values are treated as
 * natural-log probabilities; positive values are treated as probabilities.
 *
 * Callers:
 * - renderAlternativeTokensView, to render the entire view
 * - onSelectedTokenChanged, to update the view when a token is selected
 */
function renderTopLogprobs() {
    $('#logprobs_top_logprobs_hint').hide();
    const view = $('.logprobs_candidate_list');
    view.empty();

    if (!state.selectedTokenLogprobs) {
        return;
    }

    const { token: selectedToken, topLogprobs } = state.selectedTokenLogprobs;

    let sum = 0;
    const nodes = [];
    // Sort a copy: `topLogprobs` belongs to the saved message data and must not be reordered in place.
    const candidates = [...topLogprobs]
        .sort(([, logA], [, logB]) => logB - logA)
        .map(([text, log]) => {
            if (log <= 0) {
                const probability = Math.exp(log);
                sum += probability;
                return [text, probability, log];
            }
            // Already a probability; it still counts toward the covered mass.
            sum += log;
            return [text, log, null];
        });
    // Remaining mass; clamp so floating-point error can't produce a negative share.
    candidates.push(['<others>', Math.max(0, 1 - sum), null]);

    let matched = false;
    for (const [token, probability, log] of candidates) {
        const container = $('<button class="flex-container flexFlowColumn logprobs_top_candidate"></button>');
        // String() rather than .toString(): a token that failed to decode may be undefined.
        const tokenString = String(token);
        const tokenNormalized = tokenString.replace(/^[▁Ġ]/g, ' ');

        if (token === selectedToken || tokenNormalized === selectedToken) {
            matched = true;
            container.addClass('selected');
        }

        const tokenText = $('<span></span>').text(toVisibleWhitespace(tokenString));
        const percentText = $('<span></span>').text(`${(+probability * 100).toFixed(2)}%`);
        container.append(tokenText, percentText);
        if (log !== null) {
            container.attr('title', `logarithm: ${log}`);
        }
        addKeyboardProps(container);
        if (token !== '<others>') {
            container.on('click', () => onAlternativeClicked(state.selectedTokenLogprobs, tokenString));
        } else {
            container.prop('disabled', true);
        }
        nodes.push(container);
    }

    // Highlight the <others> node if the selected token was not among the top candidates.
    if (!matched) {
        nodes[nodes.length - 1].css('background-color', 'rgba(255, 0, 0, 0.1)');
    }

    view.append(nodes);
}
|
|
|
|
/**
 * Click handler for a token in the output view. Selecting a token shows its
 * candidates; clicking the already-selected token clears the selection.
 * The top-logprobs subview is re-rendered either way.
 *
 * @param {TokenLogprobs} logprobs - logprob data of the clicked token
 * @param {Node|JQuery} span - the clicked token's span
 */
function onSelectedTokenChanged(logprobs, span) {
    $('.logprobs_output_token.selected').removeClass('selected');

    const isDeselecting = state.selectedTokenLogprobs === logprobs;
    state.selectedTokenLogprobs = isDeselecting ? null : logprobs;
    if (!isDeselecting) {
        $(span).addClass('selected');
    }

    renderTopLogprobs();
}
|
|
|
|
/**
 * onAlternativeClicked is called when the user clicks on an alternative token
 * in the top logprobs view. It creates a new swipe prefilled with all text up
 * to the selected token, with that token replaced by the chosen alternative,
 * and then requests a `continue` completion from the model.
 *
 * Shows an "unavailable" popup for the OpenAI API instead. Does nothing if a
 * generation is running, or if the token no longer belongs to the active
 * message's logprobs (e.g. the message changed after the list was rendered).
 *
 * @param {TokenLogprobs} tokenLogprobs - logprob data for the token being replaced
 * @param {string} alternative - selected alternative token's text
 */
function onAlternativeClicked(tokenLogprobs, alternative) {
    if (!checkGenerateReady()) {
        return;
    }

    if (getGeneratingApi() === 'openai') {
        const title = t`Feature unavailable`;
        const message = t`Due to API limitations, rerolling a token is not supported with OpenAI. Try switching to a different API.`;
        const content = `<h3>${title}</h3><p>${message}</p>`;
        return callGenericPopup(content, POPUP_TYPE.TEXT);
    }

    const data = getActiveMessageLogprobData();
    const replaceIndex = data ? data.messageLogprobs.indexOf(tokenLogprobs) : -1;
    if (replaceIndex === -1) {
        // Without this guard, index -1 would yield an empty token list and
        // silently reroll from the bare prefix instead of the chosen point.
        console.warn('Selected token is not part of the active message logprobs; ignoring reroll.');
        return;
    }

    const { messageLogprobs, continueFrom } = data;
    const tokens = messageLogprobs.slice(0, replaceIndex + 1).map(({ token }) => token);
    // Map tokenizer whitespace markers back to real characters.
    tokens[replaceIndex] = String(alternative).replace(/^[▁Ġ]/g, ' ').replace(/Ċ/g, '\n');

    const prefix = continueFrom || '';
    const prompt = prefix + tokens.join('');
    addGeneration(prompt);
}
|
|
|
|
/**
 * Handles a click on the reroll button or on a word of the continued prefix.
 * The active message's `continueFrom` prefix is cut at `offset` characters and
 * passed to `addGeneration` as the new prompt.
 *
 * Without an `offset`, the whole prefix is kept, so only the text generated
 * after it is rerolled. If the active message has no prefix, an empty prompt
 * is passed. Does nothing while a generation is in progress.
 *
 * @param {number} [offset] - character offset into the prefix to truncate at
 * @returns {void}
 */
function onPrefixClicked(offset = undefined) {
    if (checkGenerateReady()) {
        const continueFrom = getActiveMessageLogprobData()?.continueFrom;
        addGeneration(continueFrom ? continueFrom.substring(0, offset) : '');
    }
}
|
|
|
|
/**
 * Guards reroll actions against overlapping generations.
 * @returns {boolean} true when no generation is in progress; otherwise shows a
 *   warning toast and returns false
 */
function checkGenerateReady() {
    const busy = Boolean(is_send_press);
    if (busy) {
        toastr.warning('Please wait for the current generation to complete.');
    }
    return !busy;
}
|
|
|
|
/**
 * Produces a new swipe on the last chat message. With a non-empty `prompt`,
 * the swipe is seeded with it and a `continue` generation is started;
 * otherwise the last message is simply swiped right.
 *
 * Used when the user picks an alternative token or rerolls from a prefix.
 *
 * @param {string} prompt - text to seed the new swipe with
 */
function addGeneration(prompt) {
    const hasPrompt = Boolean(prompt) && prompt.length > 0;
    if (hasPrompt) {
        createSwipe(chat.length - 1, prompt);
    }
    $('.swipe_right:last').trigger('click');
    if (hasPrompt) {
        void Generate('continue');
    }
}
|
|
|
|
/**
 * Shows or hides the `#logprobsViewer` panel with a fade, e.g. from the Token
 * Probabilities menu item or the panel's close button. The `resizing` class is
 * applied for the duration of the fade and removed 50 ms after it ends.
 */
function onToggleLogprobsPanel() {
    const logprobsViewer = $('#logprobsViewer');
    const isHidden = logprobsViewer.css('display') === 'none';
    const finishResizing = async () => {
        await delay(50);
        logprobsViewer.removeClass('resizing');
    };

    // largely copied from CFGScale toggle
    logprobsViewer.addClass('resizing');

    if (!isHidden) {
        logprobsViewer.transition({ opacity: 0.0, duration: animation_duration }, finishResizing);
        setTimeout(() => logprobsViewer.hide(), animation_duration);
        return;
    }

    // Lay the panel out (transparent) before rendering: renderAlternativeTokensView
    // skips work when its output element isn't visible (presumably a child of this panel).
    logprobsViewer.css('display', 'flex');
    logprobsViewer.css('opacity', 0.0);
    renderAlternativeTokensView();
    logprobsViewer.transition({ opacity: 1.0, duration: animation_duration }, finishResizing);
}
|
|
|
|
/**
 * Appends a new swipe holding `prompt` to a chat message. The message's
 * `swipe_id` is left pointing at the swipe just before the new one, so that a
 * subsequent swipe-right lands on it.
 *
 * @param {number} messageId - index of the target message in `chat`
 * @param {string} prompt - initial swipe text, which will be continued
 */
function createSwipe(messageId, prompt) {
    // The prompt is built from raw model output, so it still lacks the
    // trimming / macro replacements a normal reply gets; apply them now.
    const swipeText = cleanUpMessage(prompt, false, false, true);
    const message = chat[messageId];

    const swipeInfo = {
        send_date: message.send_date,
        gen_started: message.gen_started,
        gen_finished: message.gen_finished,
        extra: Object.assign({}, structuredClone(message.extra), { from_logprobs: Date.now() }),
    };

    if (!message.swipes) {
        message.swipes = [];
    }
    if (!message.swipe_info) {
        message.swipe_info = [];
    }

    message.swipes.push(swipeText);
    message.swipe_info.push(swipeInfo);
    message.swipe_id = Math.max(0, message.swipes.length - 2);
}
|
|
|
|
/**
 * Makes whitespace visible for display: spaces and the tokenizer space markers
 * `▁` / `Ġ` become `·`, while newlines and the `Ċ` newline marker become `↵`.
 * Other characters (including tabs) are left unchanged.
 * @param {string} input
 * @returns {string}
 */
function toVisibleWhitespace(input) {
    return input.replace(/[ ▁ĠĊ\n]/g, (ch) => (ch === 'Ċ' || ch === '\n' ? '↵' : '·'));
}
|
|
|
|
/**
 * Surrounds a token span with helper nodes so text can still wrap after its
 * whitespace has been replaced by visible glyphs:
 * - a zero-width space before the span if the token starts with whitespace or
 *   with a `▁` / `Ġ` space marker, and after it if the token ends with whitespace;
 * - a `<br>` before and/or after it for a leading and/or trailing newline.
 *
 * Only leading/trailing line breaks are handled. Consecutive breaks, or breaks
 * between non-whitespace characters, are not handled; tokenizers generally
 * don't produce those.
 *
 * @param {string} text - token text being evaluated for whitespace
 * @param {Node|JQuery} span - target span node to be wrapped
 * @returns {NodeArray} - nodes to append to the parent element, in order
 */
function withVirtualWhitespace(text, span) {
    const has = (pattern) => text.match(pattern) !== null;
    const zeroWidthSpace = () => document.createTextNode('\u200b');

    /** @type {NodeArray} */
    const leading = [];
    /** @type {NodeArray} */
    const trailing = [];

    if (has(/^\s/)) {
        leading.unshift(zeroWidthSpace());
    }
    if (has(/\s$/)) {
        trailing.push($(zeroWidthSpace()));
    }
    if (has(/^[▁Ġ]/)) {
        leading.unshift(zeroWidthSpace());
    }

    // leading line break, at least one character, and trailing line break
    if (has(/^\n(?:.|\n)+\n$/)) {
        leading.unshift($('<br>'));
        trailing.push($('<br>'));
    } else if (has(/^\n/)) {
        leading.unshift($('<br>'));
    } else if (has(/\n$/)) {
        trailing.push($('<br>'));
    }

    return [...leading, span, ...trailing];
}
|
|
|
|
/**
 * Stores per-token logprobs for the active (last) chat message, keyed by the
 * message's hash. Only the MAX_MESSAGE_LOGPROBS most recently created entries
 * are kept.
 *
 * The active message must already be updated and rendered when this is called,
 * otherwise the data is attached to the wrong message.
 *
 * Callers:
 * - Generate:onSuccess via saveLogprobsForActiveMessage, for non-streaming text completion
 * - StreamingProcessor:onFinishStreaming, for streaming text completion
 * - sendOpenAIRequest, for non-streaming chat completion
 *
 * @param {TokenLogprobs[]} logprobs - logprob data for each generated token
 * @param {string | null} continueFrom - for 'continue' generations, the prompt
 *   the generation continued from
 */
export function saveLogprobsForActiveMessage(logprobs, continueFrom) {
    // Non-streaming APIs may return no logprob data at all.
    if (!logprobs) {
        return;
    }

    // NovelAI reports token IDs instead of text; decode them in place first.
    if (getGeneratingApi() === 'novel') {
        convertTokenIdLogprobsToText(logprobs);
    }

    const messageId = chat.length - 1;
    const message = chat[messageId];

    /** @type {MessageLogprobData} */
    const entry = {
        created: Date.now(),
        api: getGeneratingApi(),
        messageId,
        swipeId: message.swipe_id,
        messageLogprobs: logprobs,
        continueFrom,
        hash: getMessageHash(message),
    };
    state.messageLogprobs.set(entry.hash, entry);

    // Evict everything older than the newest MAX_MESSAGE_LOGPROBS entries.
    const newestFirst = [...state.messageLogprobs.values()].sort((a, b) => b.created - a.created);
    newestFirst
        .slice(MAX_MESSAGE_LOGPROBS)
        .forEach(({ hash }) => state.messageLogprobs.delete(hash));
}
|
|
|
|
/**
 * Hashes a chat message's identity: its author name, its index in `chat`, and
 * its text. The swipe ID is deliberately left out; deleting a swipe renumbers
 * all later swipes, so it is not a stable identifier.
 * @param {object} message - a message object stored in `chat`
 * @returns {number} hash of the message identity
 */
function getMessageHash(message) {
    const { name, mes: text } = message;
    const mid = chat.indexOf(message);
    return getStringHash(JSON.stringify({ name, mid, text }));
}
|
|
|
|
/**
 * getActiveMessageLogprobData returns the logprobs data saved for the active
 * (last) chat message.
 * @returns {MessageLogprobData | null} the saved data, or null when nothing is
 *   saved for that message or the chat is empty
 */
function getActiveMessageLogprobData() {
    if (chat.length === 0) {
        // No active message; hashing `undefined` would throw.
        return null;
    }
    const hash = getMessageHash(chat[chat.length - 1]);
    return state.messageLogprobs.get(hash) || null;
}
|
|
|
|
|
|
/**
 * Rewrites logprob data in place, replacing numeric token IDs with decoded
 * token text. Needed for APIs that report IDs instead of text (NovelAI).
 * Every distinct ID is decoded in a single tokenizer call.
 *
 * @param {TokenLogprobs[]} input - logprobs data with numeric token IDs; mutated
 * @throws {Error} if the generating API is not NovelAI
 */
function convertTokenIdLogprobsToText(input) {
    const api = getGeneratingApi();
    if (api !== 'novel') {
        // should have been checked by the caller
        throw new Error('convertTokenIdLogprobsToText should only be called for NovelAI');
    }

    const tokenizerId = getTokenizerBestMatch(api);

    // Distinct IDs in first-seen order: each position's candidates, then its chosen token.
    const allIds = input.flatMap(({ token, topLogprobs }) => topLogprobs.map(([id]) => id).concat(token));
    /** @type {any[]} */
    const uniqueIds = [...new Set(allIds)];

    // noinspection JSCheckFunctionSignatures - mutates input in-place
    const { chunks } = decodeTextTokens(tokenizerId, uniqueIds);
    const textById = new Map();
    uniqueIds.forEach((id, i) => textById.set(id, chunks[i]));

    for (const entry of input) {
        entry.token = textById.get(entry.token);
        entry.topLogprobs = entry.topLogprobs.map(([id, logprob]) => [textById.get(id), logprob]);
    }
}
|
|
|
|
/**
 * Wires up the Token Probabilities panel. It hides the reroll button and binds
 * the open/close controls to `onToggleLogprobsPanel`. It also re-renders the
 * view (debounced) on chat/message events that can change the active message.
 */
export function initLogprobs() {
    REROLL_BUTTON.hide();
    const debouncedRender = debounce(renderAlternativeTokensView);

    for (const selector of ['#logprobsViewerClose', '#option_toggle_logprobs']) {
        $(selector).on('click', onToggleLogprobsPanel);
    }

    const rerenderEvents = [
        event_types.CHAT_CHANGED,
        event_types.CHARACTER_MESSAGE_RENDERED,
        event_types.IMPERSONATE_READY,
        event_types.MESSAGE_DELETED,
        event_types.MESSAGE_EDITED,
        event_types.MESSAGE_SWIPED,
    ];
    for (const eventType of rerenderEvents) {
        eventSource.on(eventType, debouncedRender);
    }
}
|