diff --git a/public/scripts/extensions/memory/index.js b/public/scripts/extensions/memory/index.js index 5829ead28..452eea941 100644 --- a/public/scripts/extensions/memory/index.js +++ b/public/scripts/extensions/memory/index.js @@ -25,6 +25,7 @@ import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js'; import { SlashCommand } from '../../slash-commands/SlashCommand.js'; import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js'; import { MacrosParser } from '../../macros.js'; +import { countWebLlmTokens, generateWebLlmChatPrompt, getWebLlmContextSize, isWebLlmSupported } from '../shared.js'; export { MODULE_NAME }; const MODULE_NAME = '1_memory'; @@ -36,6 +37,40 @@ let lastMessageHash = null; let lastMessageId = null; let inApiCall = false; +/** + * Count the number of tokens in the provided text. + * @param {string} text Text to count tokens for + * @returns {Promise} Number of tokens in the text + */ +async function countSourceTokens(text, padding = 0) { + if (extension_settings.memory.source === summary_sources.webllm) { + const count = await countWebLlmTokens(text); + return count + padding; + } + + if (extension_settings.memory.source === summary_sources.extras) { + const count = getTextTokens(tokenizers.GPT2, text).length; + return count + padding; + } + + return await getTokenCountAsync(text, padding); +} + +async function getSourceContextSize() { + const overrideLength = extension_settings.memory.overrideResponseLength; + + if (extension_settings.memory.source === summary_sources.webllm) { + const maxContext = await getWebLlmContextSize(); + return overrideLength > 0 ? 
(maxContext - overrideLength) : Math.round(maxContext * 0.75); + } + + if (extension_settings.memory.source === summary_sources.extras) { + return 1024; + } + + return getMaxContextSize(overrideLength); +} + const formatMemoryValue = function (value) { if (!value) { return ''; @@ -55,6 +90,7 @@ const saveChatDebounced = debounce(() => getContext().saveChat(), debounce_timeo const summary_sources = { 'extras': 'extras', 'main': 'main', + 'webllm': 'webllm', }; const prompt_builders = { @@ -130,12 +166,12 @@ function loadSettings() { async function onPromptForceWordsAutoClick() { const context = getContext(); - const maxPromptLength = getMaxContextSize(extension_settings.memory.overrideResponseLength); + const maxPromptLength = await getSourceContextSize(); const chat = context.chat; const allMessages = chat.filter(m => !m.is_system && m.mes).map(m => m.mes); const messagesWordCount = allMessages.map(m => extractAllWords(m)).flat().length; const averageMessageWordCount = messagesWordCount / allMessages.length; - const tokensPerWord = await getTokenCountAsync(allMessages.join('\n')) / messagesWordCount; + const tokensPerWord = await countSourceTokens(allMessages.join('\n')) / messagesWordCount; const wordsPerToken = 1 / tokensPerWord; const maxPromptLengthWords = Math.round(maxPromptLength * wordsPerToken); // How many words should pass so that messages will start be dropped out of context; @@ -168,15 +204,15 @@ async function onPromptForceWordsAutoClick() { async function onPromptIntervalAutoClick() { const context = getContext(); - const maxPromptLength = getMaxContextSize(extension_settings.memory.overrideResponseLength); + const maxPromptLength = await getSourceContextSize(); const chat = context.chat; const allMessages = chat.filter(m => !m.is_system && m.mes).map(m => m.mes); const messagesWordCount = allMessages.map(m => extractAllWords(m)).flat().length; - const messagesTokenCount = await getTokenCountAsync(allMessages.join('\n')); + const messagesTokenCount = await 
countSourceTokens(allMessages.join('\n')); const tokensPerWord = messagesTokenCount / messagesWordCount; const averageMessageTokenCount = messagesTokenCount / allMessages.length; const targetSummaryTokens = Math.round(extension_settings.memory.promptWords * tokensPerWord); - const promptTokens = await getTokenCountAsync(extension_settings.memory.prompt); + const promptTokens = await countSourceTokens(extension_settings.memory.prompt); const promptAllowance = maxPromptLength - promptTokens - targetSummaryTokens; const maxMessagesPerSummary = extension_settings.memory.maxMessagesPerRequest || 0; const averageMessagesPerPrompt = Math.floor(promptAllowance / averageMessageTokenCount); @@ -213,8 +249,8 @@ function onSummarySourceChange(event) { function switchSourceControls(value) { $('#memory_settings [data-summary-source]').each((_, element) => { - const source = $(element).data('summary-source'); - $(element).toggle(source === value); + const source = element.dataset.summarySource.split(',').map(s => s.trim()); + $(element).toggle(source.includes(value)); }); } @@ -359,6 +395,12 @@ async function onChatEvent() { } } + if (extension_settings.memory.source === summary_sources.webllm) { + if (!isWebLlmSupported()) { + return; + } + } + const context = getContext(); const chat = context.chat; @@ -431,8 +473,12 @@ async function forceSummarizeChat() { return ''; } - toastr.info('Summarizing chat...', 'Please wait'); - const value = await summarizeChatMain(context, true, skipWIAN); + const toast = toastr.info('Summarizing chat...', 'Please wait', { timeOut: 0, extendedTimeOut: 0 }); + const value = extension_settings.memory.source === summary_sources.main + ? 
await summarizeChatMain(context, true, skipWIAN) + : await summarizeChatWebLLM(context, true); + + toastr.clear(toast); if (!value) { toastr.warning('Failed to summarize chat'); @@ -484,16 +530,25 @@ async function summarizeChat(context) { case summary_sources.main: await summarizeChatMain(context, false, skipWIAN); break; + case summary_sources.webllm: + await summarizeChatWebLLM(context, false); + break; default: break; } } -async function summarizeChatMain(context, force, skipWIAN) { - +/** + * Check if the chat should be summarized based on the current conditions. + * Return summary prompt if it should be summarized. + * @param {any} context ST context + * @param {boolean} force Summarize the chat regardless of the conditions + * @returns {Promise} Summary prompt or empty string + */ +async function getSummaryPromptForNow(context, force) { if (extension_settings.memory.promptInterval === 0 && !force) { console.debug('Prompt interval is set to 0, skipping summarization'); - return; + return ''; } try { @@ -505,17 +560,17 @@ async function summarizeChatMain(context, force, skipWIAN) { waitUntilCondition(() => is_send_press === false, 30000, 100); } catch { console.debug('Timeout waiting for is_send_press'); - return; + return ''; } if (!context.chat.length) { console.debug('No messages in chat to summarize'); - return; + return ''; } if (context.chat.length < extension_settings.memory.promptInterval && !force) { console.debug(`Not enough messages in chat to summarize (chat: ${context.chat.length}, interval: ${extension_settings.memory.promptInterval})`); - return; + return ''; } let messagesSinceLastSummary = 0; @@ -539,7 +594,7 @@ async function summarizeChatMain(context, force, skipWIAN) { if (!conditionSatisfied && !force) { console.debug(`Summary conditions not satisfied (messages: ${messagesSinceLastSummary}, interval: ${extension_settings.memory.promptInterval}, words: ${wordsSinceLastSummary}, force words: ${extension_settings.memory.promptForceWords})`); 
- return; + return ''; } console.log('Summarizing chat, messages since last summary: ' + messagesSinceLastSummary, 'words since last summary: ' + wordsSinceLastSummary); @@ -547,6 +602,63 @@ async function summarizeChatMain(context, force, skipWIAN) { if (!prompt) { console.debug('Summarization prompt is empty. Skipping summarization.'); + return ''; + } + + return prompt; +} + +async function summarizeChatWebLLM(context, force) { + if (!isWebLlmSupported()) { + return; + } + + const prompt = await getSummaryPromptForNow(context, force); + + if (!prompt) { + return; + } + + const { rawPrompt, lastUsedIndex } = await getRawSummaryPrompt(context, prompt); + + if (lastUsedIndex === null || lastUsedIndex === -1) { + if (force) { + toastr.info('To try again, remove the latest summary.', 'No messages found to summarize'); + } + + return null; + } + + const messages = [ + { role: 'system', content: prompt }, + { role: 'user', content: rawPrompt }, + ]; + + const params = {}; + + if (extension_settings.memory.overrideResponseLength > 0) { + params.max_tokens = extension_settings.memory.overrideResponseLength; + } + + const summary = await generateWebLlmChatPrompt(messages, params); + const newContext = getContext(); + + // something changed during summarization request + if (newContext.groupId !== context.groupId || + newContext.chatId !== context.chatId || + (!newContext.groupId && (newContext.characterId !== context.characterId))) { + console.log('Context changed, summary discarded'); + return; + } + + setMemoryContext(summary, true, lastUsedIndex); + return summary; +} + +async function summarizeChatMain(context, force, skipWIAN) { + const prompt = await getSummaryPromptForNow(context, force); + + if (!prompt) { return; } @@ -634,7 +746,7 @@ async function getRawSummaryPrompt(context, prompt) { chat.pop(); // We always exclude the last message from the buffer const chatBuffer = []; const PADDING = 64; - const PROMPT_SIZE = 
getMaxContextSize(extension_settings.memory.overrideResponseLength); + const PROMPT_SIZE = await getSourceContextSize(); let latestUsedMessage = null; for (let index = latestSummaryIndex + 1; index < chat.length; index++) { @@ -651,7 +763,7 @@ async function getRawSummaryPrompt(context, prompt) { const entry = `${message.name}:\n${message.mes}`; chatBuffer.push(entry); - const tokens = await getTokenCountAsync(getMemoryString(true), PADDING); + const tokens = await countSourceTokens(getMemoryString(true), PADDING); if (tokens > PROMPT_SIZE) { chatBuffer.pop(); diff --git a/public/scripts/extensions/memory/settings.html b/public/scripts/extensions/memory/settings.html index f72a7f7b5..ac6b51728 100644 --- a/public/scripts/extensions/memory/settings.html +++ b/public/scripts/extensions/memory/settings.html @@ -13,6 +13,7 @@
@@ -24,7 +25,7 @@
- -
+