Merge pull request #1734 from khanonnie/alternative-tokens

Implement Token Probabilities UI panel using logprobs
This commit is contained in:
Cohee
2024-01-26 03:39:25 +02:00
committed by GitHub
16 changed files with 921 additions and 35 deletions

View File

@@ -105,6 +105,7 @@ import {
nai_settings,
adjustNovelInstructionPrompt,
loadNovelSubscriptionData,
parseNovelAILogprobs,
} from './scripts/nai-settings.js';
import {
@@ -169,6 +170,7 @@ import { markdownExclusionExt } from './scripts/showdown-exclusion.js';
import { NOTE_MODULE_NAME, initAuthorsNote, metadata_keys, setFloatingPrompt, shouldWIAddPrompt } from './scripts/authors-note.js';
import { registerPromptManagerMigration } from './scripts/PromptManager.js';
import { getRegexedString, regex_placement } from './scripts/extensions/regex/engine.js';
import { initLogprobs, saveLogprobsForActiveMessage } from './scripts/logprobs.js';
import { FILTER_TYPES, FilterHelper } from './scripts/filters.js';
import { getCfgPrompt, getGuidanceScale, initCfg } from './scripts/cfg-scale.js';
import {
@@ -197,6 +199,7 @@ import { evaluateMacros } from './scripts/macros.js';
//exporting functions and vars for mods
export {
Generate,
cleanUpMessage,
getSettings,
saveSettings,
saveSettingsDebounced,
@@ -204,6 +207,7 @@ export {
clearChat,
getChat,
getCharacters,
getGeneratingApi,
callPopup,
substituteParams,
sendSystemMessage,
@@ -824,6 +828,7 @@ async function firstLoadInit() {
initRossMods();
initStats();
initCfg();
initLogprobs();
doDailyExtensionUpdatesCheck();
hideLoader();
await eventSource.emit(event_types.APP_READY);
@@ -2475,6 +2480,8 @@ class StreamingProcessor {
this.timeStarted = timeStarted;
this.messageAlreadyGenerated = messageAlreadyGenerated;
this.swipes = [];
/** @type {import('./scripts/logprobs.js').TokenLogprobs[]} */
this.messageLogprobs = [];
}
showMessageButtons(messageId) {
@@ -2606,7 +2613,9 @@ class StreamingProcessor {
await eventSource.emit(event_types.IMPERSONATE_READY, text);
}
const continueMsg = this.type === 'continue' ? this.messageAlreadyGenerated : undefined;
await saveChatConditional();
saveLogprobsForActiveMessage(this.messageLogprobs.filter(Boolean), continueMsg);
activateSendButtons();
showSwipeButtons();
setGenerationProgress(0);
@@ -2692,7 +2701,7 @@ class StreamingProcessor {
try {
const sw = new Stopwatch(1000 / power_user.streaming_fps);
const timestamps = [];
for await (const { text, swipes } of this.generator()) {
for await (const { text, swipes, logprobs } of this.generator()) {
timestamps.push(Date.now());
if (this.isStopped) {
return;
@@ -2700,6 +2709,9 @@ class StreamingProcessor {
this.result = text;
this.swipes = swipes;
if (logprobs) {
this.messageLogprobs.push(...(Array.isArray(logprobs) ? logprobs : [logprobs]));
}
await sw.tick(() => this.onProgressStreaming(this.messageId, this.messageAlreadyGenerated + text));
}
const seconds = (timestamps[timestamps.length - 1] - timestamps[0]) / 1000;
@@ -3783,6 +3795,9 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
else {
({ type, getMessage } = await saveReply('appendFinal', getMessage, false, title, swipes));
}
// This relies on `saveReply` having been called to add the message to the chat, so it must be last.
parseAndSaveLogprobs(data, continue_mag);
}
if (type !== 'quiet') {
@@ -4392,6 +4407,34 @@ function extractTitleFromData(data) {
return undefined;
}
/**
 * Parses token logprobs out of the full (non-streaming) generation response
 * and attaches them to the currently active chat message.
 *
 * Only APIs whose raw response reaches this point with logprob data are
 * handled here; chat-completion backends extract logprobs earlier, inside
 * `sendOpenAIRequest`, because their `data` is already reduced to plain text.
 *
 * @param {object} data - full response payload containing tokens/logprobs
 * @param {string} continueFrom - for 'continue' generations, the text being continued
 */
function parseAndSaveLogprobs(data, continueFrom) {
    // Anything other than NovelAI either has no logprobs in `data` or has
    // already saved them upstream, so bail out without touching the message.
    if (main_api !== 'novel') {
        return;
    }

    // The NovelAI parser consumes one token/logprob pair per call, so map it
    // over each entry; fall back to null when the response carries none.
    /** @type {import('./scripts/logprobs.js').TokenLogprobs[] | null} */
    const logprobs = data.logprobs?.map(parseNovelAILogprobs) || null;

    saveLogprobsForActiveMessage(logprobs, continueFrom);
}
/**
* Extracts the message from the response data.
* @param {object} data Response data