Implement Token Probabilities UI using logprobs

khanon
2024-01-22 23:00:31 -06:00
parent 9b42be2334
commit 60044c18a4
16 changed files with 921 additions and 35 deletions


@@ -105,6 +105,7 @@ import {
nai_settings,
adjustNovelInstructionPrompt,
loadNovelSubscriptionData,
parseNovelAILogprobs,
} from './scripts/nai-settings.js';
import {
@@ -169,6 +170,7 @@ import { markdownExclusionExt } from './scripts/showdown-exclusion.js';
import { NOTE_MODULE_NAME, initAuthorsNote, metadata_keys, setFloatingPrompt, shouldWIAddPrompt } from './scripts/authors-note.js';
import { registerPromptManagerMigration } from './scripts/PromptManager.js';
import { getRegexedString, regex_placement } from './scripts/extensions/regex/engine.js';
import { initLogprobs, saveLogprobsForActiveMessage } from './scripts/logprobs.js';
import { FILTER_TYPES, FilterHelper } from './scripts/filters.js';
import { getCfgPrompt, getGuidanceScale, initCfg } from './scripts/cfg-scale.js';
import {
@@ -197,6 +199,7 @@ import { evaluateMacros } from './scripts/macros.js';
//exporting functions and vars for mods
export {
Generate,
cleanUpMessage,
getSettings,
saveSettings,
saveSettingsDebounced,
@@ -204,6 +207,7 @@ export {
clearChat,
getChat,
getCharacters,
getGeneratingApi,
callPopup,
substituteParams,
sendSystemMessage,
@@ -773,6 +777,7 @@ async function firstLoadInit() {
initRossMods();
initStats();
initCfg();
initLogprobs();
doDailyExtensionUpdatesCheck();
hideLoader();
await eventSource.emit(event_types.APP_READY);
@@ -2392,6 +2397,8 @@ class StreamingProcessor {
this.timeStarted = timeStarted;
this.messageAlreadyGenerated = messageAlreadyGenerated;
this.swipes = [];
/** @type {import('./scripts/logprobs.js').TokenLogprobs[]} */
this.messageLogprobs = [];
}
showMessageButtons(messageId) {
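
As a reading aid (not part of the diff): each entry pushed into messageLogprobs is assumed to pair one emitted token with its candidate alternatives. The authoritative TokenLogprobs typedef lives in scripts/logprobs.js; the shape below is an illustrative assumption.

// Assumed shape of a single TokenLogprobs entry (illustrative only; see scripts/logprobs.js).
const exampleTokenLogprobs = {
    token: ' world',                                                  // token emitted in the message
    topLogprobs: [[' world', -0.12], [' there', -2.7], ['!', -4.1]],  // candidate [token, logprob] pairs
};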
@@ -2522,7 +2529,9 @@
await eventSource.emit(event_types.IMPERSONATE_READY, text);
}
const continueMsg = this.type === 'continue' ? this.messageAlreadyGenerated : undefined;
await saveChatConditional();
saveLogprobsForActiveMessage(this.messageLogprobs.filter(Boolean), continueMsg);
activateSendButtons();
showSwipeButtons();
setGenerationProgress(0);
@@ -2608,7 +2617,7 @@
try {
const sw = new Stopwatch(1000 / power_user.streaming_fps);
const timestamps = [];
for await (const { text, swipes } of this.generator()) {
for await (const { text, swipes, logprobs } of this.generator()) {
timestamps.push(Date.now());
if (this.isStopped) {
return;
@@ -2616,6 +2625,9 @@
this.result = text;
this.swipes = swipes;
if (logprobs) {
this.messageLogprobs.push(...(Array.isArray(logprobs) ? logprobs : [logprobs]));
}
await sw.tick(() => this.onProgressStreaming(this.messageId, this.messageAlreadyGenerated + text));
}
const seconds = (timestamps[timestamps.length - 1] - timestamps[0]) / 1000;
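
To make the consumption pattern concrete, here is a minimal, self-contained sketch (not from the diff) of a generator yielding chunks in the { text, swipes, logprobs } shape that the loop above iterates over; the values are invented for illustration.

// Fake streaming generator mirroring the chunk shape consumed by the streaming loop above.
async function* fakeGenerator() {
    yield { text: 'Hel', swipes: [], logprobs: { token: 'Hel', topLogprobs: [['Hel', -0.3]] } };
    yield { text: 'Hello', swipes: [], logprobs: { token: 'lo', topLogprobs: [['lo', -0.2]] } };
}

(async () => {
    const messageLogprobs = [];
    for await (const { text, swipes, logprobs } of fakeGenerator()) {
        if (logprobs) {
            // Some backends report one entry per chunk, others an array; normalize to a flat list.
            messageLogprobs.push(...(Array.isArray(logprobs) ? logprobs : [logprobs]));
        }
    }
    console.log(messageLogprobs.length); // 2
})();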
@@ -3699,6 +3711,9 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
else {
({ type, getMessage } = await saveReply('appendFinal', getMessage, false, title, swipes));
}
// This relies on `saveReply` having been called to add the message to the chat, so it must be last.
parseAndSaveLogprobs(data, continue_mag);
}
if (type !== 'quiet') {
@@ -4308,6 +4323,34 @@ function extractTitleFromData(data) {
return undefined;
}
/**
* parseAndSaveLogprobs receives the full data response for a non-streaming
* generation, parses logprobs for all tokens in the message, and saves them
* to the currently active message.
* @param {object} data - response data containing all tokens/logprobs
* @param {string} continueFrom - for 'continue' generations, the prompt
* */
function parseAndSaveLogprobs(data, continueFrom) {
/** @type {import('./scripts/logprobs.js').TokenLogprobs[] | null} */
let logprobs = null;
switch (main_api) {
case 'novel':
// parser only handles one token/logprob pair at a time
logprobs = data.logprobs?.map(parseNovelAILogprobs) || null;
break;
case 'openai':
// OAI and other chat completion APIs must handle this earlier in
// `sendOpenAIRequest`. `data` for these APIs is just a string with
// the text of the generated message, logprobs are not included.
return;
default:
return;
}
saveLogprobsForActiveMessage(logprobs, continueFrom);
}
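
For illustration, the 'novel' branch boils down to a map-or-null over data.logprobs. The stub below stands in for parseNovelAILogprobs (which lives in scripts/nai-settings.js); the response field names are assumptions, not the actual NovelAI payload.

// Stub parser: converts one assumed response entry into the TokenLogprobs-like shape used above.
function stubParseLogprobs(entry) {
    return { token: entry.token, topLogprobs: entry.candidates };
}

const fakeData = {
    logprobs: [
        { token: 'Hello', candidates: [['Hello', -0.2], ['Hi', -1.9]] },
        { token: ',', candidates: [[',', -0.1], ['!', -2.4]] },
    ],
};

// Mirrors the 'novel' case: one parser call per token, or null when logprobs are absent.
const parsed = fakeData.logprobs?.map(stubParseLogprobs) || null;
console.log(parsed.length); // 2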
/**
* Extracts the message from the response data.
* @param {object} data Response data