From 2bdc6f27cc95807bb954962eefdf359003849462 Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Mon, 12 Aug 2024 21:56:32 +0300
Subject: [PATCH 1/8] Add SillyTavern globals

---
 .eslintrc.js       | 1 +
 public/global.d.ts | 5 +++++
 2 files changed, 6 insertions(+)

diff --git a/.eslintrc.js b/.eslintrc.js
index f05a857ad..8d9e36fe3 100644
--- a/.eslintrc.js
+++ b/.eslintrc.js
@@ -55,6 +55,7 @@ module.exports = {
                 isProbablyReaderable: 'readonly',
                 ePub: 'readonly',
                 diff_match_patch: 'readonly',
+                SillyTavern: 'readonly',
             },
         },
     ],
diff --git a/public/global.d.ts b/public/global.d.ts
index 1a5ee091b..c8bfe14c2 100644
--- a/public/global.d.ts
+++ b/public/global.d.ts
@@ -14,6 +14,11 @@ declare var isProbablyReaderable;
 declare var ePub;
 declare var ai;
 
+declare var SillyTavern: {
+    getContext(): any;
+    llm: any;
+};
+
 // Jquery plugins
 interface JQuery {
     nanogallery2(options?: any): JQuery;

From 77ab694ea0326f381a97057e320723d530aac74e Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Mon, 12 Aug 2024 22:07:44 +0300
Subject: [PATCH 2/8] Add shared utilities for generating text with WebLLM

---
 public/scripts/extensions/shared.js | 83 +++++++++++++++++++++++++++++
 1 file changed, 83 insertions(+)

diff --git a/public/scripts/extensions/shared.js b/public/scripts/extensions/shared.js
index 403d23f7f..793d66653 100644
--- a/public/scripts/extensions/shared.js
+++ b/public/scripts/extensions/shared.js
@@ -176,3 +176,86 @@ function throwIfInvalidModel(useReverseProxy) {
         throw new Error('Custom API URL is not set.');
     }
 }
+
+/**
+ * Check if the WebLLM extension is installed and supported.
+ * @returns {boolean} Whether the extension is installed and supported
+ */
+export function isWebLlmSupported() {
+    if (!('gpu' in navigator)) {
+        toastr.error('Your browser does not support the WebGPU API. Please use a different browser.', 'WebLLM', {
+            preventDuplicates: true,
+            timeOut: 0,
+            extendedTimeOut: 0,
+        });
+        return false;
+    }
+
+    if (!('llm' in SillyTavern)) {
+        toastr.error('WebLLM extension is not installed. Click here to install it.', 'WebLLM', {
+            timeOut: 0,
+            extendedTimeOut: 0,
+            preventDuplicates: true,
+            onclick: () => {
+                const button = document.getElementById('third_party_extension_button');
+                if (button) {
+                    button.click();
+                }
+
+                const input = document.querySelector('dialog textarea');
+
+                if (input instanceof HTMLTextAreaElement) {
+                    input.value = 'https://github.com/SillyTavern/Extension-WebLLM';
+                }
+            },
+        });
+        return false;
+    }
+
+    return true;
+}
+
+/**
+ * Generates text in response to a chat prompt using WebLLM.
+ * @param {any[]} messages Messages to use for generating
+ * @returns {Promise<string>} Generated response
+ */
+export async function generateWebLlmChatPrompt(messages) {
+    if (!isWebLlmSupported()) {
+        throw new Error('WebLLM extension is not installed.');
+    }
+
+    const engine = SillyTavern.llm;
+    const response = await engine.generateChatPrompt(messages);
+    return response;
+}
+
+/**
+ * Counts the number of tokens in the provided text using WebLLM's default model.
+ * @param {string} text Text to count tokens in
+ * @returns {Promise<number>} Number of tokens in the text
+ */
+export async function countWebLlmTokens(text) {
+    if (!isWebLlmSupported()) {
+        throw new Error('WebLLM extension is not installed.');
+    }
+
+    const engine = SillyTavern.llm;
+    const response = await engine.countTokens(text);
+    return response;
+}
+
+/**
+ * Gets the size of the context in the WebLLM's default model.
+ * @returns {Promise<number>} Size of the context in the WebLLM model
+ */
+export async function getWebLlmContextSize() {
+    if (!isWebLlmSupported()) {
+        throw new Error('WebLLM extension is not installed.');
+    }
+
+    const engine = SillyTavern.llm;
+    await engine.loadModel();
+    const model = await engine.getCurrentModelInfo();
+    return model?.context_size;
+}
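These four helpers are the whole surface the later patches build on: gate on isWebLlmSupported(), generate with generateWebLlmChatPrompt(), and budget with countWebLlmTokens() and getWebLlmContextSize(). As a minimal sketch of a consumer (illustrative only; the function name and prompt text below are not part of the patch):

    import { isWebLlmSupported, countWebLlmTokens, generateWebLlmChatPrompt } from './shared.js';

    async function summarizeWithWebLlm(text) {
        // Bail out early; the helper also raises a toast explaining what is missing.
        if (!isWebLlmSupported()) {
            return null;
        }

        // Check the input size before generating.
        console.debug('Input size:', await countWebLlmTokens(text), 'tokens');

        // Same { role, content } message shape the memory extension uses in patch 3.
        return await generateWebLlmChatPrompt([
            { role: 'system', content: 'Summarize the text below.' },
            { role: 'user', content: text },
        ]);
    }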
From 8685c2f471bda4784ed52600620bd7019216a293 Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Mon, 12 Aug 2024 23:01:03 +0300
Subject: [PATCH 3/8] Add WebLLM extension summarization

---
 public/scripts/extensions/memory/index.js     | 148 +++++++++++++++---
 .../scripts/extensions/memory/settings.html   |   7 +-
 public/scripts/extensions/shared.js           |  53 ++++---
 3 files changed, 165 insertions(+), 43 deletions(-)

diff --git a/public/scripts/extensions/memory/index.js b/public/scripts/extensions/memory/index.js
index 5829ead28..452eea941 100644
--- a/public/scripts/extensions/memory/index.js
+++ b/public/scripts/extensions/memory/index.js
@@ -25,6 +25,7 @@ import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js';
 import { SlashCommand } from '../../slash-commands/SlashCommand.js';
 import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js';
 import { MacrosParser } from '../../macros.js';
+import { countWebLlmTokens, generateWebLlmChatPrompt, getWebLlmContextSize, isWebLlmSupported } from '../shared.js';
 
 export { MODULE_NAME };
 const MODULE_NAME = '1_memory';
@@ -36,6 +37,40 @@ let lastMessageHash = null;
 let lastMessageId = null;
 let inApiCall = false;
 
+/**
+ * Count the number of tokens in the provided text.
+ * @param {string} text Text to count tokens for
+ * @returns {Promise<number>} Number of tokens in the text
+ */
+async function countSourceTokens(text, padding = 0) {
+    if (extension_settings.memory.source === summary_sources.webllm) {
+        const count = await countWebLlmTokens(text);
+        return count + padding;
+    }
+
+    if (extension_settings.memory.source === summary_sources.extras) {
+        const count = getTextTokens(tokenizers.GPT2, text).length;
+        return count + padding;
+    }
+
+    return await getTokenCountAsync(text, padding);
+}
+
+async function getSourceContextSize() {
+    const overrideLength = extension_settings.memory.overrideResponseLength;
+
+    if (extension_settings.memory.source === summary_sources.webllm) {
+        const maxContext = await getWebLlmContextSize();
+        return overrideLength > 0 ? (maxContext - overrideLength) : Math.round(maxContext * 0.75);
+    }
+
+    if (extension_settings.memory.source === summary_sources.extras) {
+        return 1024;
+    }
+
+    return getMaxContextSize(overrideLength);
+}
+
 const formatMemoryValue = function (value) {
     if (!value) {
         return '';
     }
@@ -55,6 +90,7 @@ const saveChatDebounced = debounce(() => getContext().saveChat(), debounce_timeout.relaxed);
 const summary_sources = {
     'extras': 'extras',
     'main': 'main',
+    'webllm': 'webllm',
 };
 
 const prompt_builders = {
@@ -130,12 +166,12 @@ function loadSettings() {
 
 async function onPromptForceWordsAutoClick() {
     const context = getContext();
-    const maxPromptLength = getMaxContextSize(extension_settings.memory.overrideResponseLength);
+    const maxPromptLength = await getSourceContextSize();
     const chat = context.chat;
     const allMessages = chat.filter(m => !m.is_system && m.mes).map(m => m.mes);
     const messagesWordCount = allMessages.map(m => extractAllWords(m)).flat().length;
     const averageMessageWordCount = messagesWordCount / allMessages.length;
-    const tokensPerWord = await getTokenCountAsync(allMessages.join('\n')) / messagesWordCount;
+    const tokensPerWord = await countSourceTokens(allMessages.join('\n')) / messagesWordCount;
     const wordsPerToken = 1 / tokensPerWord;
     const maxPromptLengthWords = Math.round(maxPromptLength * wordsPerToken);
     // How many words should pass so that messages will start be dropped out of context;
@@ -168,15 +204,15 @@ async function onPromptForceWordsAutoClick() {
 
 async function onPromptIntervalAutoClick() {
     const context = getContext();
-    const maxPromptLength = getMaxContextSize(extension_settings.memory.overrideResponseLength);
+    const maxPromptLength = await getSourceContextSize();
     const chat = context.chat;
     const allMessages = chat.filter(m => !m.is_system && m.mes).map(m => m.mes);
     const messagesWordCount = allMessages.map(m => extractAllWords(m)).flat().length;
-    const messagesTokenCount = await getTokenCountAsync(allMessages.join('\n'));
+    const messagesTokenCount = await countSourceTokens(allMessages.join('\n'));
     const tokensPerWord = messagesTokenCount / messagesWordCount;
     const averageMessageTokenCount = messagesTokenCount / allMessages.length;
     const targetSummaryTokens = Math.round(extension_settings.memory.promptWords * tokensPerWord);
-    const promptTokens = await getTokenCountAsync(extension_settings.memory.prompt);
+    const promptTokens = await countSourceTokens(extension_settings.memory.prompt);
     const promptAllowance = maxPromptLength - promptTokens - targetSummaryTokens;
     const maxMessagesPerSummary = extension_settings.memory.maxMessagesPerRequest || 0;
     const averageMessagesPerPrompt = Math.floor(promptAllowance / averageMessageTokenCount);
@@ -213,8 +249,8 @@ function onSummarySourceChange(event) {
 
 function switchSourceControls(value) {
     $('#memory_settings [data-summary-source]').each((_, element) => {
-        const source = $(element).data('summary-source');
-        $(element).toggle(source === value);
+        const source = element.dataset.summarySource.split(',').map(s => s.trim());
+        $(element).toggle(source.includes(value));
     });
 }
 
@@ -359,6 +395,12 @@ async function onChatEvent() {
         }
     }
 
+    if (extension_settings.memory.source === summary_sources.webllm) {
+        if (!isWebLlmSupported()) {
+            return;
+        }
+    }
+
     const context = getContext();
     const chat = context.chat;
 
@@ -431,8 +473,12 @@ async function forceSummarizeChat() {
         return '';
     }
 
-    toastr.info('Summarizing chat...', 'Please wait');
-    const value = await summarizeChatMain(context, true, skipWIAN);
+    const toast = toastr.info('Summarizing chat...', 'Please wait', { timeOut: 0, extendedTimeOut: 0 });
+    const value = extension_settings.memory.source === summary_sources.main
+        ? await summarizeChatMain(context, true, skipWIAN)
+        : await summarizeChatWebLLM(context, true);
+
+    toastr.clear(toast);
 
     if (!value) {
         toastr.warning('Failed to summarize chat');
@@ -484,16 +530,25 @@ async function summarizeChat(context) {
         case summary_sources.main:
             await summarizeChatMain(context, false, skipWIAN);
             break;
+        case summary_sources.webllm:
+            await summarizeChatWebLLM(context, false);
+            break;
         default:
             break;
     }
 }
 
-async function summarizeChatMain(context, force, skipWIAN) {
-
+/**
+ * Check if the chat should be summarized based on the current conditions.
+ * Return summary prompt if it should be summarized.
+ * @param {any} context ST context
+ * @param {boolean} force Summarize the chat regardless of the conditions
+ * @returns {Promise<string>} Summary prompt or empty string
+ */
+async function getSummaryPromptForNow(context, force) {
     if (extension_settings.memory.promptInterval === 0 && !force) {
         console.debug('Prompt interval is set to 0, skipping summarization');
-        return;
+        return '';
     }
 
     try {
         await waitUntilCondition(() => is_send_press === false, 30000, 100);
     } catch {
         console.debug('Timeout waiting for is_send_press');
-        return;
+        return '';
     }
 
     if (!context.chat.length) {
         console.debug('No messages in chat to summarize');
-        return;
+        return '';
     }
 
     if (context.chat.length < extension_settings.memory.promptInterval && !force) {
         console.debug(`Not enough messages in chat to summarize (chat: ${context.chat.length}, interval: ${extension_settings.memory.promptInterval})`);
-        return;
+        return '';
     }
 
     let messagesSinceLastSummary = 0;
@@ -539,7 +594,7 @@ async function summarizeChatMain(context, force, skipWIAN) {
 
     if (!conditionSatisfied && !force) {
         console.debug(`Summary conditions not satisfied (messages: ${messagesSinceLastSummary}, interval: ${extension_settings.memory.promptInterval}, words: ${wordsSinceLastSummary}, force words: ${extension_settings.memory.promptForceWords})`);
-        return;
+        return '';
     }
 
     console.log('Summarizing chat, messages since last summary: ' + messagesSinceLastSummary, 'words since last summary: ' + wordsSinceLastSummary);
@@ -547,6 +602,63 @@ async function summarizeChatMain(context, force, skipWIAN) {
 
     if (!prompt) {
         console.debug('Summarization prompt is empty. Skipping summarization.');
+        return '';
+    }
+
+    return prompt;
+}
+
+async function summarizeChatWebLLM(context, force) {
+    if (!isWebLlmSupported()) {
+        return;
+    }
+
+    const prompt = await getSummaryPromptForNow(context, force);
+
+    if (!prompt) {
+        return;
+    }
+
+    const { rawPrompt, lastUsedIndex } = await getRawSummaryPrompt(context, prompt);
+
+    if (lastUsedIndex === null || lastUsedIndex === -1) {
+        if (force) {
+            toastr.info('To try again, remove the latest summary.', 'No messages found to summarize');
+        }
+
+        return null;
+    }
+
+    const messages = [
+        { role: 'system', content: prompt },
+        { role: 'user', content: rawPrompt },
+    ];
+
+    const params = {};
+
+    if (extension_settings.memory.overrideResponseLength > 0) {
+        params.max_tokens = extension_settings.memory.overrideResponseLength;
+    }
+
+    const summary = await generateWebLlmChatPrompt(messages, params);
+    const newContext = getContext();
+
+    // something changed during summarization request
+    if (newContext.groupId !== context.groupId ||
+        newContext.chatId !== context.chatId ||
+        (!newContext.groupId && (newContext.characterId !== context.characterId))) {
+        console.log('Context changed, summary discarded');
+        return;
+    }
+
+    setMemoryContext(summary, true, lastUsedIndex);
+    return summary;
+}
+
+async function summarizeChatMain(context, force, skipWIAN) {
+    const prompt = await getSummaryPromptForNow(context, force);
+
+    if (!prompt) {
         return;
     }
@@ -634,7 +746,7 @@ async function getRawSummaryPrompt(context, prompt) {
     chat.pop(); // We always exclude the last message from the buffer
     const chatBuffer = [];
     const PADDING = 64;
-    const PROMPT_SIZE = getMaxContextSize(extension_settings.memory.overrideResponseLength);
+    const PROMPT_SIZE = await getSourceContextSize();
     let latestUsedMessage = null;
 
     for (let index = latestSummaryIndex + 1; index < chat.length; index++) {
@@ -651,7 +763,7 @@ async function getRawSummaryPrompt(context, prompt) {
         const entry = `${message.name}:\n${message.mes}`;
         chatBuffer.push(entry);
 
-        const tokens = await countSourceTokens(getMemoryString(true), PADDING);
+        const tokens = await countSourceTokens(getMemoryString(true), PADDING);
 
         if (tokens > PROMPT_SIZE) {
             chatBuffer.pop();
diff --git a/public/scripts/extensions/memory/settings.html b/public/scripts/extensions/memory/settings.html
index f72a7f7b5..ac6b51728 100644
--- a/public/scripts/extensions/memory/settings.html
+++ b/public/scripts/extensions/memory/settings.html
@@ -13,6 +13,7 @@
[markup lost in extraction — per the hunk header, one line is added here, putting a "webllm" choice in the summary source selector]
@@ -24,7 +25,7 @@
[markup lost in extraction — this hunk rewrites data-summary-source attribute values as comma-separated source lists (e.g. including "webllm" alongside "main"), matching the new switchSourceControls() logic above]
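Both auto-tune buttons in patch 3 reduce to one budget calculation: estimate tokens-per-word across the whole chat, subtract the summarization prompt and the target summary length from the source's context size, and divide the remainder by the average message cost. The same arithmetic in isolation (all numbers are made up for illustration; in the patch they come from countSourceTokens() and getSourceContextSize()):

    // Hypothetical chat statistics
    const maxPromptLength = 4096;     // await getSourceContextSize()
    const messagesTokenCount = 12000; // await countSourceTokens(allMessages.join('\n'))
    const messagesWordCount = 9000;
    const messageCount = 150;
    const promptTokens = 160;         // cost of the summarization prompt itself
    const promptWords = 200;          // extension_settings.memory.promptWords

    const tokensPerWord = messagesTokenCount / messagesWordCount;                    // ~1.33
    const targetSummaryTokens = Math.round(promptWords * tokensPerWord);             // 267
    const averageMessageTokenCount = messagesTokenCount / messageCount;              // 80
    const promptAllowance = maxPromptLength - promptTokens - targetSummaryTokens;    // 3669
    const averageMessagesPerPrompt = Math.floor(promptAllowance / averageMessageTokenCount); // 45

onPromptIntervalAutoClick then reconciles this value with the optional maxMessagesPerRequest cap before suggesting a summarization interval.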
From 8921b78f8788a4a605ee10cfa83a457c941b8895 Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Tue, 13 Aug 2024 19:57:38 +0300
Subject: [PATCH 8/8] Add debug logs to WebLLM completions

---
 public/scripts/extensions/shared.js | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/public/scripts/extensions/shared.js b/public/scripts/extensions/shared.js
index b50908940..706acea62 100644
--- a/public/scripts/extensions/shared.js
+++ b/public/scripts/extensions/shared.js
@@ -223,8 +223,10 @@ export async function generateWebLlmChatPrompt(messages, params = {}) {
     if (!isWebLlmSupported()) {
         throw new Error('WebLLM extension is not installed.');
     }
 
+    console.debug('WebLLM chat completion request:', messages, params);
     const engine = SillyTavern.llm;
     const response = await engine.generateChatPrompt(messages, params);
+    console.debug('WebLLM chat completion response:', response);
     return response;
 }
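With both logs in place, every WebLLM request can be paired with its response in the browser console. For example (illustrative values; max_tokens is the same parameter the memory extension sets from its response length override):

    const summary = await generateWebLlmChatPrompt(
        [
            { role: 'system', content: 'Summarize the conversation.' },
            { role: 'user', content: rawPrompt }, // rawPrompt: the chat buffer, as built by getRawSummaryPrompt()
        ],
        { max_tokens: 256 },
    );
    // Console output (debug level), per the two added lines:
    //   WebLLM chat completion request: [...] { max_tokens: 256 }
    //   WebLLM chat completion response: <generated summary text>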