Add WebLLM extension summarization

This commit is contained in:
Cohee 2024-08-12 23:01:03 +03:00
parent 77ab694ea0
commit 8685c2f471
3 changed files with 165 additions and 43 deletions

View File

@ -25,6 +25,7 @@ import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js';
import { SlashCommand } from '../../slash-commands/SlashCommand.js';
import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js';
import { MacrosParser } from '../../macros.js';
import { countWebLlmTokens, generateWebLlmChatPrompt, getWebLlmContextSize, isWebLlmSupported } from '../shared.js';
export { MODULE_NAME };
const MODULE_NAME = '1_memory';
@ -36,6 +37,40 @@ let lastMessageHash = null;
let lastMessageId = null;
let inApiCall = false;
/**
* Count the number of tokens in the provided text.
* @param {string} text Text to count tokens for
* @returns {Promise<number>} Number of tokens in the text
*/
async function countSourceTokens(text, padding = 0) {
if (extension_settings.memory.source === summary_sources.webllm) {
const count = await countWebLlmTokens(text);
return count + padding;
}
if (extension_settings.memory.source === summary_sources.extras) {
const count = getTextTokens(tokenizers.GPT2, text).length;
return count + padding;
}
return await getTokenCountAsync(text, padding);
}
async function getSourceContextSize() {
const overrideLength = extension_settings.memory.overrideResponseLength;
if (extension_settings.memory.source === summary_sources.webllm) {
const maxContext = await getWebLlmContextSize();
return overrideLength > 0 ? (maxContext - overrideLength) : Math.round(maxContext * 0.75);
}
if (extension_settings.source === summary_sources.extras) {
return 1024;
}
return getMaxContextSize(overrideLength);
}
const formatMemoryValue = function (value) {
if (!value) {
return '';
@ -55,6 +90,7 @@ const saveChatDebounced = debounce(() => getContext().saveChat(), debounce_timeo
const summary_sources = {
'extras': 'extras',
'main': 'main',
'webllm': 'webllm',
};
const prompt_builders = {
@ -130,12 +166,12 @@ function loadSettings() {
async function onPromptForceWordsAutoClick() {
const context = getContext();
const maxPromptLength = getMaxContextSize(extension_settings.memory.overrideResponseLength);
const maxPromptLength = await getSourceContextSize();
const chat = context.chat;
const allMessages = chat.filter(m => !m.is_system && m.mes).map(m => m.mes);
const messagesWordCount = allMessages.map(m => extractAllWords(m)).flat().length;
const averageMessageWordCount = messagesWordCount / allMessages.length;
const tokensPerWord = await getTokenCountAsync(allMessages.join('\n')) / messagesWordCount;
const tokensPerWord = await countSourceTokens(allMessages.join('\n')) / messagesWordCount;
const wordsPerToken = 1 / tokensPerWord;
const maxPromptLengthWords = Math.round(maxPromptLength * wordsPerToken);
// How many words should pass so that messages will start be dropped out of context;
@ -168,15 +204,15 @@ async function onPromptForceWordsAutoClick() {
async function onPromptIntervalAutoClick() {
const context = getContext();
const maxPromptLength = getMaxContextSize(extension_settings.memory.overrideResponseLength);
const maxPromptLength = await getSourceContextSize();
const chat = context.chat;
const allMessages = chat.filter(m => !m.is_system && m.mes).map(m => m.mes);
const messagesWordCount = allMessages.map(m => extractAllWords(m)).flat().length;
const messagesTokenCount = await getTokenCountAsync(allMessages.join('\n'));
const messagesTokenCount = await countSourceTokens(allMessages.join('\n'));
const tokensPerWord = messagesTokenCount / messagesWordCount;
const averageMessageTokenCount = messagesTokenCount / allMessages.length;
const targetSummaryTokens = Math.round(extension_settings.memory.promptWords * tokensPerWord);
const promptTokens = await getTokenCountAsync(extension_settings.memory.prompt);
const promptTokens = await countSourceTokens(extension_settings.memory.prompt);
const promptAllowance = maxPromptLength - promptTokens - targetSummaryTokens;
const maxMessagesPerSummary = extension_settings.memory.maxMessagesPerRequest || 0;
const averageMessagesPerPrompt = Math.floor(promptAllowance / averageMessageTokenCount);
@ -213,8 +249,8 @@ function onSummarySourceChange(event) {
function switchSourceControls(value) {
$('#memory_settings [data-summary-source]').each((_, element) => {
const source = $(element).data('summary-source');
$(element).toggle(source === value);
const source = element.dataset.summarySource.split(',').map(s => s.trim());
$(element).toggle(source.includes(value));
});
}
@ -359,6 +395,12 @@ async function onChatEvent() {
}
}
if (extension_settings.memory.source === summary_sources.webllm) {
if (!isWebLlmSupported()) {
return;
}
}
const context = getContext();
const chat = context.chat;
@ -431,8 +473,12 @@ async function forceSummarizeChat() {
return '';
}
toastr.info('Summarizing chat...', 'Please wait');
const value = await summarizeChatMain(context, true, skipWIAN);
const toast = toastr.info('Summarizing chat...', 'Please wait', { timeOut: 0, extendedTimeOut: 0 });
const value = extension_settings.memory.source === summary_sources.main
? await summarizeChatMain(context, true, skipWIAN)
: await summarizeChatWebLLM(context, true);
toastr.clear(toast);
if (!value) {
toastr.warning('Failed to summarize chat');
@ -484,16 +530,25 @@ async function summarizeChat(context) {
case summary_sources.main:
await summarizeChatMain(context, false, skipWIAN);
break;
case summary_sources.webllm:
await summarizeChatWebLLM(context, false);
break;
default:
break;
}
}
async function summarizeChatMain(context, force, skipWIAN) {
/**
* Check if the chat should be summarized based on the current conditions.
* Return summary prompt if it should be summarized.
* @param {any} context ST context
* @param {boolean} force Summarize the chat regardless of the conditions
* @returns {Promise<string>} Summary prompt or empty string
*/
async function getSummaryPromptForNow(context, force) {
if (extension_settings.memory.promptInterval === 0 && !force) {
console.debug('Prompt interval is set to 0, skipping summarization');
return;
return '';
}
try {
@ -505,17 +560,17 @@ async function summarizeChatMain(context, force, skipWIAN) {
waitUntilCondition(() => is_send_press === false, 30000, 100);
} catch {
console.debug('Timeout waiting for is_send_press');
return;
return '';
}
if (!context.chat.length) {
console.debug('No messages in chat to summarize');
return;
return '';
}
if (context.chat.length < extension_settings.memory.promptInterval && !force) {
console.debug(`Not enough messages in chat to summarize (chat: ${context.chat.length}, interval: ${extension_settings.memory.promptInterval})`);
return;
return '';
}
let messagesSinceLastSummary = 0;
@ -539,7 +594,7 @@ async function summarizeChatMain(context, force, skipWIAN) {
if (!conditionSatisfied && !force) {
console.debug(`Summary conditions not satisfied (messages: ${messagesSinceLastSummary}, interval: ${extension_settings.memory.promptInterval}, words: ${wordsSinceLastSummary}, force words: ${extension_settings.memory.promptForceWords})`);
return;
return '';
}
console.log('Summarizing chat, messages since last summary: ' + messagesSinceLastSummary, 'words since last summary: ' + wordsSinceLastSummary);
@ -547,6 +602,63 @@ async function summarizeChatMain(context, force, skipWIAN) {
if (!prompt) {
console.debug('Summarization prompt is empty. Skipping summarization.');
return '';
}
return prompt;
}
async function summarizeChatWebLLM(context, force) {
if (!isWebLlmSupported()) {
return;
}
const prompt = await getSummaryPromptForNow(context, force);
if (!prompt) {
return;
}
const { rawPrompt, lastUsedIndex } = await getRawSummaryPrompt(context, prompt);
if (lastUsedIndex === null || lastUsedIndex === -1) {
if (force) {
toastr.info('To try again, remove the latest summary.', 'No messages found to summarize');
}
return null;
}
const messages = [
{ role: 'system', content: prompt },
{ role: 'user', content: rawPrompt },
];
const params = {};
if (extension_settings.memory.overrideResponseLength > 0) {
params.max_tokens = extension_settings.memory.overrideResponseLength;
}
const summary = await generateWebLlmChatPrompt(messages, params);
const newContext = getContext();
// something changed during summarization request
if (newContext.groupId !== context.groupId ||
newContext.chatId !== context.chatId ||
(!newContext.groupId && (newContext.characterId !== context.characterId))) {
console.log('Context changed, summary discarded');
return;
}
setMemoryContext(summary, true, lastUsedIndex);
return summary;
}
async function summarizeChatMain(context, force, skipWIAN) {
const prompt = await getSummaryPromptForNow(context, force);
if (!prompt) {
return;
}
@ -634,7 +746,7 @@ async function getRawSummaryPrompt(context, prompt) {
chat.pop(); // We always exclude the last message from the buffer
const chatBuffer = [];
const PADDING = 64;
const PROMPT_SIZE = getMaxContextSize(extension_settings.memory.overrideResponseLength);
const PROMPT_SIZE = await getSourceContextSize();
let latestUsedMessage = null;
for (let index = latestSummaryIndex + 1; index < chat.length; index++) {
@ -651,7 +763,7 @@ async function getRawSummaryPrompt(context, prompt) {
const entry = `${message.name}:\n${message.mes}`;
chatBuffer.push(entry);
const tokens = await getTokenCountAsync(getMemoryString(true), PADDING);
const tokens = await countSourceTokens(getMemoryString(true), PADDING);
if (tokens > PROMPT_SIZE) {
chatBuffer.pop();

View File

@ -13,6 +13,7 @@
<select id="summary_source">
<option value="main" data-i18n="ext_sum_main_api">Main API</option>
<option value="extras">Extras API</option>
<option value="webllm" data-i18n="ext_sum_webllm">WebLLM Extension</option>
</select><br>
<div class="flex-container justifyspacebetween alignitemscenter">
@ -24,7 +25,7 @@
<textarea id="memory_contents" class="text_pole textarea_compact" rows="6" data-i18n="[placeholder]ext_sum_memory_placeholder" placeholder="Summary will be generated here..."></textarea>
<div class="memory_contents_controls">
<div id="memory_force_summarize" data-summary-source="main" class="menu_button menu_button_icon" title="Trigger a summary update right now." data-i18n="[title]ext_sum_force_tip">
<div id="memory_force_summarize" data-summary-source="main,webllm" class="menu_button menu_button_icon" title="Trigger a summary update right now." data-i18n="[title]ext_sum_force_tip">
<i class="fa-solid fa-database"></i>
<span data-i18n="ext_sum_force_text">Summarize now</span>
</div>
@ -58,7 +59,7 @@
<span data-i18n="ext_sum_prompt_builder_3">Classic, blocking</span>
</label>
</div>
<div data-summary-source="main">
<div data-summary-source="main,webllm">
<label for="memory_prompt" class="title_restorable">
<span data-i18n="Summary Prompt">Summary Prompt</span>
<div id="memory_prompt_restore" data-i18n="[title]ext_sum_restore_default_prompt_tip" title="Restore default prompt" class="right_menu_button">
@ -74,7 +75,7 @@
</label>
<input id="memory_override_response_length" type="range" value="{{defaultSettings.overrideResponseLength}}" min="{{defaultSettings.overrideResponseLengthMin}}" max="{{defaultSettings.overrideResponseLengthMax}}" step="{{defaultSettings.overrideResponseLengthStep}}" />
<label for="memory_max_messages_per_request">
<span data-i18n="ext_sum_raw_max_msg">[Raw] Max messages per request</span> (<span id="memory_max_messages_per_request_value"></span>)
<span data-i18n="ext_sum_raw_max_msg">[Raw/WebLLM] Max messages per request</span> (<span id="memory_max_messages_per_request_value"></span>)
<small class="memory_disabled_hint" data-i18n="ext_sum_0_unlimited">0 = unlimited</small>
</label>
<input id="memory_max_messages_per_request" type="range" value="{{defaultSettings.maxMessagesPerRequest}}" min="{{defaultSettings.maxMessagesPerRequestMin}}" max="{{defaultSettings.maxMessagesPerRequestMax}}" step="{{defaultSettings.maxMessagesPerRequestStep}}" />

View File

@ -183,32 +183,40 @@ function throwIfInvalidModel(useReverseProxy) {
*/
export function isWebLlmSupported() {
if (!('gpu' in navigator)) {
toastr.error('Your browser does not support the WebGPU API. Please use a different browser.', 'WebLLM', {
preventDuplicates: true,
timeOut: 0,
extendedTimeOut: 0,
});
const warningKey = 'webllm_browser_warning_shown';
if (!sessionStorage.getItem(warningKey)) {
toastr.error('Your browser does not support the WebGPU API. Please use a different browser.', 'WebLLM', {
preventDuplicates: true,
timeOut: 0,
extendedTimeOut: 0,
});
sessionStorage.setItem(warningKey, '1');
}
return false;
}
if (!('llm' in SillyTavern)) {
toastr.error('WebLLM extension is not installed. Click here to install it.', 'WebLLM', {
timeOut: 0,
extendedTimeOut: 0,
preventDuplicates: true,
onclick: () => {
const button = document.getElementById('third_party_extension_button');
if (button) {
button.click();
}
const warningKey = 'webllm_extension_warning_shown';
if (!sessionStorage.getItem(warningKey)) {
toastr.error('WebLLM extension is not installed. Click here to install it.', 'WebLLM', {
timeOut: 0,
extendedTimeOut: 0,
preventDuplicates: true,
onclick: () => {
const button = document.getElementById('third_party_extension_button');
if (button) {
button.click();
}
const input = document.querySelector('dialog textarea');
const input = document.querySelector('dialog textarea');
if (input instanceof HTMLTextAreaElement) {
input.value = 'https://github.com/SillyTavern/Extension-WebLLM';
}
},
});
if (input instanceof HTMLTextAreaElement) {
input.value = 'https://github.com/SillyTavern/Extension-WebLLM';
}
},
});
sessionStorage.setItem(warningKey, '1');
}
return false;
}
@ -218,15 +226,16 @@ export function isWebLlmSupported() {
/**
* Generates text in response to a chat prompt using WebLLM.
* @param {any[]} messages Messages to use for generating
* @param {object} params Additional parameters
* @returns {Promise<string>} Generated response
*/
export async function generateWebLlmChatPrompt(messages) {
export async function generateWebLlmChatPrompt(messages, params = {}) {
if (!isWebLlmSupported()) {
throw new Error('WebLLM extension is not installed.');
}
const engine = SillyTavern.llm;
const response = await engine.generateChatPrompt(messages);
const response = await engine.generateChatPrompt(messages, params);
return response;
}