Implement Token Probabilities UI using logprobs

This commit is contained in:
khanon
2024-01-22 23:00:31 -06:00
parent 9b42be2334
commit 60044c18a4
16 changed files with 921 additions and 35 deletions

View File

@ -63,6 +63,7 @@ import {
formatInstructModeSystemPrompt,
} from './instruct-mode.js';
import { isMobile } from './RossAscends-mods.js';
import { saveLogprobsForActiveMessage } from './logprobs.js';
export {
openai_messages_count,
@ -1534,6 +1535,7 @@ async function sendOpenAIRequest(type, messages, signal) {
const isImpersonate = type === 'impersonate';
const isContinue = type === 'continue';
const stream = oai_settings.stream_openai && !isQuiet && !isScale && !isAI21 && !(isGoogle && oai_settings.google_model.includes('bison'));
const useLogprobs = !!power_user.request_token_probabilities;
if (isTextCompletion && isOpenRouter) {
messages = convertChatCompletionToInstruct(messages, type);
@ -1601,6 +1603,11 @@ async function sendOpenAIRequest(type, messages, signal) {
generate_data['proxy_password'] = oai_settings.proxy_password;
}
// Add logprobs request (currently OpenAI only, max 5 on their side)
if (useLogprobs && isOAI) {
generate_data['logprobs'] = 5;
}
if (isClaude) {
generate_data['top_k'] = Number(oai_settings.top_k_openai);
generate_data['exclude_assistant'] = oai_settings.exclude_assistant;
@ -1689,8 +1696,9 @@ async function sendOpenAIRequest(type, messages, signal) {
const rawData = isSSEStream ? value.data : utf8Decoder.decode(value, { stream: true });
if (isSSEStream && rawData === '[DONE]') return;
tryParseStreamingError(response, rawData);
text += getStreamingReply(JSON.parse(rawData));
yield { text, swipes: [] };
const parsed = JSON.parse(rawData);
text += getStreamingReply(parsed);
yield { text, swipes: [], logprobs: parseChatCompletionLogprobs(parsed) };
}
};
}
@ -1705,6 +1713,13 @@ async function sendOpenAIRequest(type, messages, signal) {
throw new Error(data);
}
if (type !== 'quiet') {
const logprobs = parseChatCompletionLogprobs(data);
// Delay is required to allow the active message to be updated to
// the one we are generating (happens right after sendOpenAIRequest)
delay(1).then(() => saveLogprobsForActiveMessage(logprobs, null));
}
return !isTextCompletion ? data.choices[0]['message']['content'] : data.choices[0]['text'];
}
}
@ -1719,6 +1734,88 @@ function getStreamingReply(data) {
}
}
/**
* parseChatCompletionLogprobs converts the response data returned from a chat
* completions-like source into an array of TokenLogprobs found in the response.
* @param {Object} data - response data from a chat completions-like source
* @returns {import('logprobs.js').TokenLogprobs[] | null} converted logprobs
*/
function parseChatCompletionLogprobs(data) {
if (!data) {
return null;
}
switch (oai_settings.chat_completion_source) {
case chat_completion_sources.OPENAI:
if (!data.choices?.length) {
return null;
}
// OpenAI Text Completion API is treated as a chat completion source
// by SillyTavern, hence its presence in this function.
return textCompletionModels.includes(oai_settings.openai_model)
? parseOpenAITextLogprobs(data.choices[0]?.logprobs)
: parseOpenAIChatLogprobs(data.choices[0]?.logprobs);
default:
// implement other chat completion sources here
}
return null;
}
/**
* parseOpenAIChatLogprobs receives a `logprobs` response from OpenAI's chat
* completion API and converts into the structure used by the Token Probabilities
* view.
* @param {{content: { token: string, logprob: number, top_logprobs: { token: string, logprob: number }[] }[]}} logprobs
* @returns {import('logprobs.js').TokenLogprobs[] | null} converted logprobs
*/
function parseOpenAIChatLogprobs(logprobs) {
const { content } = logprobs ?? {};
if (!Array.isArray(content)) {
return null;
}
/** @type {({ token: string, logprob: number }) => [string, number]} */
const toTuple = (x) => [x.token, x.logprob];
return content.map(({ token, logprob, top_logprobs }) => {
// Add the chosen token to top_logprobs if it's not already there, then
// convert to a list of [token, logprob] pairs
const chosenTopToken = top_logprobs.some((top) => token === top.token);
const topLogprobs = chosenTopToken
? top_logprobs.map(toTuple)
: [...top_logprobs.map(toTuple), [token, logprob]];
return { token, topLogprobs };
});
}
/**
* parseOpenAITextLogprobs receives a `logprobs` response from OpenAI's text
* completion API and converts into the structure used by the Token Probabilities
* view.
* @param {{tokens: string[], token_logprobs: number[], top_logprobs: { token: string, logprob: number }[][]}} logprobs
* @returns {import('logprobs.js').TokenLogprobs[] | null} converted logprobs
*/
function parseOpenAITextLogprobs(logprobs) {
const { tokens, token_logprobs, top_logprobs } = logprobs ?? {};
if (!Array.isArray(tokens)) {
return null;
}
return tokens.map((token, i) => {
// Add the chosen token to top_logprobs if it's not already there, then
// convert to a list of [token, logprob] pairs
const topLogprobs = top_logprobs[i] ? Object.entries(top_logprobs[i]) : [];
const chosenTopToken = topLogprobs.some(([topToken]) => token === topToken);
if (!chosenTopToken) {
topLogprobs.push([token, token_logprobs[i]]);
}
return { token, topLogprobs };
});
}
function handleWindowError(err) {
const text = parseWindowError(err);
toastr.error(text, 'Window.ai returned an error');