Merge pull request #2651 from SillyTavern/webgpu-summary

Summarize with WebLLM extension
Cohee
2024-08-13 20:59:55 +03:00
committed by GitHub
9 changed files with 310 additions and 52 deletions

View File

@@ -55,6 +55,7 @@ module.exports = {
isProbablyReaderable: 'readonly',
ePub: 'readonly',
diff_match_patch: 'readonly',
SillyTavern: 'readonly',
},
},
],

public/global.d.ts (vendored, 5 lines changed)
View File

@@ -14,6 +14,11 @@ declare var isProbablyReaderable;
declare var ePub;
declare var ai;
declare var SillyTavern: {
getContext(): any;
llm: any;
};
// Jquery plugins
interface JQuery {
nanogallery2(options?: any): JQuery;

View File

@@ -989,6 +989,28 @@ export async function writeExtensionField(characterId, key, value) {
}
}
/**
* Prompts the user to enter the Git URL of the extension to import.
* After obtaining the Git URL, makes a POST request to '/api/extensions/install' to import the extension.
* If the extension is imported successfully, a success message is displayed.
* If the extension import fails, an error message is displayed and the error is logged to the console.
* After successfully importing the extension, the extension settings are reloaded and an 'EXTENSION_SETTINGS_LOADED' event is emitted.
* @param {string} [suggestUrl] Suggested URL to install
* @returns {Promise<void>}
*/
export async function openThirdPartyExtensionMenu(suggestUrl = '') {
const html = await renderTemplateAsync('installExtension');
const input = await callGenericPopup(html, POPUP_TYPE.INPUT, suggestUrl ?? '');
if (!input) {
console.debug('Extension install cancelled');
return;
}
const url = String(input).trim();
await installExtension(url);
}
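
Since the handler body now lives in an exported function, other modules can trigger the same install flow and pre-fill the popup. A minimal sketch of a hypothetical caller (the relative import path is an assumption):

```js
// Hypothetical caller, e.g. another extension module.
// The import path is an assumption; adjust to where extensions.js resolves from.
import { openThirdPartyExtensionMenu } from '../extensions.js';

// Opens the install popup with the URL pre-filled; the user can still
// edit the value or cancel before installExtension() runs.
await openThirdPartyExtensionMenu('https://github.com/SillyTavern/Extension-WebLLM');
```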
jQuery(async function () {
await addExtensionsButtonAndMenu();
$('#extensionsMenuButton').css('display', 'flex');
@@ -1004,28 +1026,8 @@ jQuery(async function () {
/**
* Handles the click event for the third-party extension import button.
- * Prompts the user to enter the Git URL of the extension to import.
- * After obtaining the Git URL, makes a POST request to '/api/extensions/install' to import the extension.
- * If the extension is imported successfully, a success message is displayed.
- * If the extension import fails, an error message is displayed and the error is logged to the console.
- * After successfully importing the extension, the extension settings are reloaded and a 'EXTENSION_SETTINGS_LOADED' event is emitted.
*
* @listens #third_party_extension_button#click - The click event of the '#third_party_extension_button' element.
*/
-$('#third_party_extension_button').on('click', async () => {
-const html = `<h3>Enter the Git URL of the extension to install</h3>
-<br>
-<p><b>Disclaimer:</b> Please be aware that using external extensions can have unintended side effects and may pose security risks. Always make sure you trust the source before importing an extension. We are not responsible for any damage caused by third-party extensions.</p>
-<br>
-<p>Example: <tt> https://github.com/author/extension-name </tt></p>`;
-const input = await callGenericPopup(html, POPUP_TYPE.INPUT, '');
-if (!input) {
-console.debug('Extension install cancelled');
-return;
-}
-const url = String(input).trim();
-await installExtension(url);
-});
+$('#third_party_extension_button').on('click', () => openThirdPartyExtensionMenu());
});

View File

@@ -25,6 +25,7 @@ import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js';
import { SlashCommand } from '../../slash-commands/SlashCommand.js';
import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js';
import { MacrosParser } from '../../macros.js';
import { countWebLlmTokens, generateWebLlmChatPrompt, getWebLlmContextSize, isWebLlmSupported } from '../shared.js';
export { MODULE_NAME };
const MODULE_NAME = '1_memory';
@@ -36,6 +37,41 @@ let lastMessageHash = null;
let lastMessageId = null;
let inApiCall = false;
/**
* Count the number of tokens in the provided text.
* @param {string} text Text to count tokens for
* @param {number} padding Number of additional tokens to add to the count
* @returns {Promise<number>} Number of tokens in the text
*/
async function countSourceTokens(text, padding = 0) {
if (extension_settings.memory.source === summary_sources.webllm) {
const count = await countWebLlmTokens(text);
return count + padding;
}
if (extension_settings.memory.source === summary_sources.extras) {
const count = getTextTokens(tokenizers.GPT2, text).length;
return count + padding;
}
return await getTokenCountAsync(text, padding);
}
async function getSourceContextSize() {
const overrideLength = extension_settings.memory.overrideResponseLength;
if (extension_settings.memory.source === summary_sources.webllm) {
const maxContext = await getWebLlmContextSize();
return overrideLength > 0 ? (maxContext - overrideLength) : Math.round(maxContext * 0.75);
}
if (extension_settings.memory.source === summary_sources.extras) {
return 1024 - 64;
}
return getMaxContextSize(overrideLength);
}
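
To make the budgeting above concrete: with WebLLM, an explicit response-length override is subtracted from the model's context, and otherwise three quarters of the context is kept for the prompt. Worked numbers, assuming a model context of 4096 tokens:

```js
// Assumed: getWebLlmContextSize() resolves to 4096 for the loaded model.
// overrideResponseLength = 0   -> Math.round(4096 * 0.75) = 3072 prompt tokens
// overrideResponseLength = 512 -> 4096 - 512              = 3584 prompt tokens
// The Extras branch is a fixed budget: 1024 - 64 = 960 tokens.
```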
const formatMemoryValue = function (value) {
if (!value) {
return '';
@@ -55,6 +91,7 @@ const saveChatDebounced = debounce(() => getContext().saveChat(), debounce_timeo
const summary_sources = {
'extras': 'extras',
'main': 'main',
'webllm': 'webllm',
};
const prompt_builders = {
@@ -130,12 +167,12 @@ function loadSettings() {
async function onPromptForceWordsAutoClick() {
const context = getContext();
-const maxPromptLength = getMaxContextSize(extension_settings.memory.overrideResponseLength);
+const maxPromptLength = await getSourceContextSize();
const chat = context.chat;
const allMessages = chat.filter(m => !m.is_system && m.mes).map(m => m.mes);
const messagesWordCount = allMessages.map(m => extractAllWords(m)).flat().length;
const averageMessageWordCount = messagesWordCount / allMessages.length;
-const tokensPerWord = await getTokenCountAsync(allMessages.join('\n')) / messagesWordCount;
+const tokensPerWord = await countSourceTokens(allMessages.join('\n')) / messagesWordCount;
const wordsPerToken = 1 / tokensPerWord;
const maxPromptLengthWords = Math.round(maxPromptLength * wordsPerToken);
// How many words should pass before messages start being dropped out of context
@@ -168,15 +205,15 @@ async function onPromptForceWordsAutoClick() {
async function onPromptIntervalAutoClick() {
const context = getContext();
-const maxPromptLength = getMaxContextSize(extension_settings.memory.overrideResponseLength);
+const maxPromptLength = await getSourceContextSize();
const chat = context.chat;
const allMessages = chat.filter(m => !m.is_system && m.mes).map(m => m.mes);
const messagesWordCount = allMessages.map(m => extractAllWords(m)).flat().length;
-const messagesTokenCount = await getTokenCountAsync(allMessages.join('\n'));
+const messagesTokenCount = await countSourceTokens(allMessages.join('\n'));
const tokensPerWord = messagesTokenCount / messagesWordCount;
const averageMessageTokenCount = messagesTokenCount / allMessages.length;
const targetSummaryTokens = Math.round(extension_settings.memory.promptWords * tokensPerWord);
-const promptTokens = await getTokenCountAsync(extension_settings.memory.prompt);
+const promptTokens = await countSourceTokens(extension_settings.memory.prompt);
const promptAllowance = maxPromptLength - promptTokens - targetSummaryTokens;
const maxMessagesPerSummary = extension_settings.memory.maxMessagesPerRequest || 0;
const averageMessagesPerPrompt = Math.floor(promptAllowance / averageMessageTokenCount);
@@ -213,8 +250,8 @@ function onSummarySourceChange(event) {
function switchSourceControls(value) {
$('#memory_settings [data-summary-source]').each((_, element) => {
-const source = $(element).data('summary-source');
-$(element).toggle(source === value);
+const source = element.dataset.summarySource.split(',').map(s => s.trim());
+$(element).toggle(source.includes(value));
});
}
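
The rewritten `switchSourceControls` treats `data-summary-source` as a comma-separated list rather than a single value, so one control can be shown for several sources (see the `main,webllm` attributes in the settings HTML further down). For example:

```js
// For <div data-summary-source="main,webllm">, element.dataset.summarySource
// is the string 'main,webllm':
const source = 'main,webllm'.split(',').map(s => s.trim()); // ['main', 'webllm']
source.includes('webllm'); // true  -> element stays visible
source.includes('extras'); // false -> element is hidden
```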
@@ -353,10 +390,13 @@ function getIndexOfLatestChatSummary(chat) {
async function onChatEvent() {
// Module not enabled
-if (extension_settings.memory.source === summary_sources.extras) {
-if (!modules.includes('summarize')) {
-return;
-}
+if (extension_settings.memory.source === summary_sources.extras && !modules.includes('summarize')) {
+return;
+}
+// WebLLM is not supported
+if (extension_settings.memory.source === summary_sources.webllm && !isWebLlmSupported()) {
+return;
}
const context = getContext();
@@ -431,8 +471,12 @@ async function forceSummarizeChat() {
return '';
}
-toastr.info('Summarizing chat...', 'Please wait');
-const value = await summarizeChatMain(context, true, skipWIAN);
+const toast = toastr.info('Summarizing chat...', 'Please wait', { timeOut: 0, extendedTimeOut: 0 });
+const value = extension_settings.memory.source === summary_sources.main
+? await summarizeChatMain(context, true, skipWIAN)
+: await summarizeChatWebLLM(context, true);
+toastr.clear(toast);
if (!value) {
toastr.warning('Failed to summarize chat');
@@ -464,6 +508,11 @@ async function summarizeCallback(args, text) {
return await callExtrasSummarizeAPI(text);
case summary_sources.main:
return await generateRaw(text, '', false, false, prompt, extension_settings.memory.overrideResponseLength);
case summary_sources.webllm: {
const messages = [{ role: 'system', content: prompt }, { role: 'user', content: text }].filter(m => m.content);
const params = extension_settings.memory.overrideResponseLength > 0 ? { max_tokens: extension_settings.memory.overrideResponseLength } : {};
return await generateWebLlmChatPrompt(messages, params);
}
default:
toastr.warning('Invalid summarization source specified');
return '';
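
Together with the `Object.values(summary_sources)` change to the slash command registration further down, the new case makes WebLLM addressable from STscript. A plausible invocation (the prompt text is illustrative):

```
/summarize source=webllm prompt="Condense the text into a single paragraph." <text to summarize>
```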
@@ -484,16 +533,25 @@ async function summarizeChat(context) {
case summary_sources.main:
await summarizeChatMain(context, false, skipWIAN);
break;
case summary_sources.webllm:
await summarizeChatWebLLM(context, false);
break;
default:
break;
}
}
-async function summarizeChatMain(context, force, skipWIAN) {
+/**
+* Check if the chat should be summarized based on the current conditions.
+* Return summary prompt if it should be summarized.
+* @param {any} context ST context
+* @param {boolean} force Summarize the chat regardless of the conditions
+* @returns {Promise<string>} Summary prompt or empty string
+*/
+async function getSummaryPromptForNow(context, force) {
if (extension_settings.memory.promptInterval === 0 && !force) {
console.debug('Prompt interval is set to 0, skipping summarization');
-return;
+return '';
}
try {
@@ -505,17 +563,17 @@ async function summarizeChatMain(context, force, skipWIAN) {
await waitUntilCondition(() => is_send_press === false, 30000, 100);
} catch {
console.debug('Timeout waiting for is_send_press');
-return;
+return '';
}
if (!context.chat.length) {
console.debug('No messages in chat to summarize');
-return;
+return '';
}
if (context.chat.length < extension_settings.memory.promptInterval && !force) {
console.debug(`Not enough messages in chat to summarize (chat: ${context.chat.length}, interval: ${extension_settings.memory.promptInterval})`);
-return;
+return '';
}
let messagesSinceLastSummary = 0;
@@ -539,7 +597,7 @@ async function summarizeChatMain(context, force, skipWIAN) {
if (!conditionSatisfied && !force) {
console.debug(`Summary conditions not satisfied (messages: ${messagesSinceLastSummary}, interval: ${extension_settings.memory.promptInterval}, words: ${wordsSinceLastSummary}, force words: ${extension_settings.memory.promptForceWords})`);
-return;
+return '';
}
console.log('Summarizing chat, messages since last summary: ' + messagesSinceLastSummary, 'words since last summary: ' + wordsSinceLastSummary);
@@ -547,6 +605,63 @@ async function summarizeChatMain(context, force, skipWIAN) {
if (!prompt) {
console.debug('Summarization prompt is empty. Skipping summarization.');
return '';
}
return prompt;
}
async function summarizeChatWebLLM(context, force) {
if (!isWebLlmSupported()) {
return;
}
const prompt = await getSummaryPromptForNow(context, force);
if (!prompt) {
return;
}
const { rawPrompt, lastUsedIndex } = await getRawSummaryPrompt(context, prompt);
if (lastUsedIndex === null || lastUsedIndex === -1) {
if (force) {
toastr.info('To try again, remove the latest summary.', 'No messages found to summarize');
}
return null;
}
const messages = [
{ role: 'system', content: prompt },
{ role: 'user', content: rawPrompt },
];
const params = {};
if (extension_settings.memory.overrideResponseLength > 0) {
params.max_tokens = extension_settings.memory.overrideResponseLength;
}
const summary = await generateWebLlmChatPrompt(messages, params);
const newContext = getContext();
// something changed during summarization request
if (newContext.groupId !== context.groupId ||
newContext.chatId !== context.chatId ||
(!newContext.groupId && (newContext.characterId !== context.characterId))) {
console.log('Context changed, summary discarded');
return;
}
setMemoryContext(summary, true, lastUsedIndex);
return summary;
}
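
The identity check before `setMemoryContext` exists because the WebLLM generation is asynchronous and the user may switch chats, groups, or characters while it runs; a stale summary is then discarded rather than written into the wrong chat. The same check as a standalone helper, hypothetical and for illustration only:

```js
// True when the context a request started from still matches the current one.
function isSameChatContext(before, after) {
    return after.groupId === before.groupId
        && after.chatId === before.chatId
        // Outside of group chats, the active character must also match.
        && (after.groupId || after.characterId === before.characterId);
}
```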
async function summarizeChatMain(context, force, skipWIAN) {
const prompt = await getSummaryPromptForNow(context, force);
if (!prompt) {
return;
}
@@ -634,7 +749,7 @@ async function getRawSummaryPrompt(context, prompt) {
chat.pop(); // We always exclude the last message from the buffer
const chatBuffer = [];
const PADDING = 64;
-const PROMPT_SIZE = getMaxContextSize(extension_settings.memory.overrideResponseLength);
+const PROMPT_SIZE = await getSourceContextSize();
let latestUsedMessage = null;
for (let index = latestSummaryIndex + 1; index < chat.length; index++) {
@@ -651,7 +766,7 @@ async function getRawSummaryPrompt(context, prompt) {
const entry = `${message.name}:\n${message.mes}`;
chatBuffer.push(entry);
-const tokens = await getTokenCountAsync(getMemoryString(true), PADDING);
+const tokens = await countSourceTokens(getMemoryString(true), PADDING);
if (tokens > PROMPT_SIZE) {
chatBuffer.pop();
@@ -680,7 +795,7 @@ async function summarizeChatExtras(context) {
const reversedChat = chat.slice().reverse();
reversedChat.shift();
const memoryBuffer = [];
-const CONTEXT_SIZE = 1024 - 64;
+const CONTEXT_SIZE = await getSourceContextSize();
for (const message of reversedChat) {
// we reached the point of latest memory
@@ -698,14 +813,14 @@ async function summarizeChatExtras(context) {
memoryBuffer.push(entry);
// check if token limit was reached
-const tokens = getTextTokens(tokenizers.GPT2, getMemoryString()).length;
+const tokens = await countSourceTokens(getMemoryString());
if (tokens >= CONTEXT_SIZE) {
break;
}
}
const resultingString = getMemoryString();
-const resultingTokens = getTextTokens(tokenizers.GPT2, resultingString).length;
+const resultingTokens = await countSourceTokens(resultingString);
if (!resultingString || resultingTokens < CONTEXT_SIZE) {
console.debug('Not enough context to summarize');
@@ -933,7 +1048,7 @@ jQuery(async function () {
name: 'summarize',
callback: summarizeCallback,
namedArgumentList: [
-new SlashCommandNamedArgument('source', 'API to use for summarization', [ARGUMENT_TYPE.STRING], false, false, '', ['main', 'extras']),
+new SlashCommandNamedArgument('source', 'API to use for summarization', [ARGUMENT_TYPE.STRING], false, false, '', Object.values(summary_sources)),
SlashCommandNamedArgument.fromProps({
name: 'prompt',
description: 'prompt to use for summarization',

View File

@@ -13,6 +13,7 @@
<select id="summary_source">
<option value="main" data-i18n="ext_sum_main_api">Main API</option>
<option value="extras">Extras API</option>
<option value="webllm" data-i18n="ext_sum_webllm">WebLLM Extension</option>
</select><br>
<div class="flex-container justifyspacebetween alignitemscenter">
@@ -24,7 +25,7 @@
<textarea id="memory_contents" class="text_pole textarea_compact" rows="6" data-i18n="[placeholder]ext_sum_memory_placeholder" placeholder="Summary will be generated here..."></textarea>
<div class="memory_contents_controls">
<div id="memory_force_summarize" data-summary-source="main" class="menu_button menu_button_icon" title="Trigger a summary update right now." data-i18n="[title]ext_sum_force_tip">
<div id="memory_force_summarize" data-summary-source="main,webllm" class="menu_button menu_button_icon" title="Trigger a summary update right now." data-i18n="[title]ext_sum_force_tip">
<i class="fa-solid fa-database"></i>
<span data-i18n="ext_sum_force_text">Summarize now</span>
</div>
@@ -58,7 +59,7 @@
<span data-i18n="ext_sum_prompt_builder_3">Classic, blocking</span>
</label>
</div>
<div data-summary-source="main">
<div data-summary-source="main,webllm">
<label for="memory_prompt" class="title_restorable">
<span data-i18n="Summary Prompt">Summary Prompt</span>
<div id="memory_prompt_restore" data-i18n="[title]ext_sum_restore_default_prompt_tip" title="Restore default prompt" class="right_menu_button">
@@ -74,7 +75,7 @@
</label>
<input id="memory_override_response_length" type="range" value="{{defaultSettings.overrideResponseLength}}" min="{{defaultSettings.overrideResponseLengthMin}}" max="{{defaultSettings.overrideResponseLengthMax}}" step="{{defaultSettings.overrideResponseLengthStep}}" />
<label for="memory_max_messages_per_request">
<span data-i18n="ext_sum_raw_max_msg">[Raw] Max messages per request</span> (<span id="memory_max_messages_per_request_value"></span>)
<span data-i18n="ext_sum_raw_max_msg">[Raw/WebLLM] Max messages per request</span> (<span id="memory_max_messages_per_request_value"></span>)
<small class="memory_disabled_hint" data-i18n="ext_sum_0_unlimited">0 = unlimited</small>
</label>
<input id="memory_max_messages_per_request" type="range" value="{{defaultSettings.maxMessagesPerRequest}}" min="{{defaultSettings.maxMessagesPerRequestMin}}" max="{{defaultSettings.maxMessagesPerRequestMax}}" step="{{defaultSettings.maxMessagesPerRequestStep}}" />

View File

@@ -1,5 +1,5 @@
import { getRequestHeaders } from '../../script.js';
-import { extension_settings } from '../extensions.js';
+import { extension_settings, openThirdPartyExtensionMenu } from '../extensions.js';
import { oai_settings } from '../openai.js';
import { SECRET_KEYS, secret_state } from '../secrets.js';
import { textgen_types, textgenerationwebui_settings } from '../textgen-settings.js';
@@ -176,3 +176,86 @@ function throwIfInvalidModel(useReverseProxy) {
throw new Error('Custom API URL is not set.');
}
}
/**
* Check if the WebLLM extension is installed and supported.
* @returns {boolean} Whether the extension is installed and supported
*/
export function isWebLlmSupported() {
if (!('gpu' in navigator)) {
const warningKey = 'webllm_browser_warning_shown';
if (!sessionStorage.getItem(warningKey)) {
toastr.error('Your browser does not support the WebGPU API. Please use a different browser.', 'WebLLM', {
preventDuplicates: true,
timeOut: 0,
extendedTimeOut: 0,
});
sessionStorage.setItem(warningKey, '1');
}
return false;
}
if (!('llm' in SillyTavern)) {
const warningKey = 'webllm_extension_warning_shown';
if (!sessionStorage.getItem(warningKey)) {
toastr.error('WebLLM extension is not installed. Click here to install it.', 'WebLLM', {
timeOut: 0,
extendedTimeOut: 0,
preventDuplicates: true,
onclick: () => openThirdPartyExtensionMenu('https://github.com/SillyTavern/Extension-WebLLM'),
});
sessionStorage.setItem(warningKey, '1');
}
return false;
}
return true;
}
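
Note that `isWebLlmSupported` reports failure by returning `false` and raising a one-time toast (throttled per session via `sessionStorage`) rather than throwing, so callers decide how hard to fail. Both conventions appear in this diff:

```js
// 1) Quiet bail-out, as in summarizeChatWebLLM() above:
if (!isWebLlmSupported()) {
    return;
}

// 2) Hard failure for direct programmatic use, as in generateWebLlmChatPrompt():
if (!isWebLlmSupported()) {
    throw new Error('WebLLM extension is not installed.');
}
```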
/**
* Generates text in response to a chat prompt using WebLLM.
* @param {any[]} messages Messages to use for generating
* @param {object} params Additional parameters
* @returns {Promise<string>} Generated response
*/
export async function generateWebLlmChatPrompt(messages, params = {}) {
if (!isWebLlmSupported()) {
throw new Error('WebLLM extension is not installed.');
}
console.debug('WebLLM chat completion request:', messages, params);
const engine = SillyTavern.llm;
const response = await engine.generateChatPrompt(messages, params);
console.debug('WebLLM chat completion response:', response);
return response;
}
/**
* Counts the number of tokens in the provided text using WebLLM's default model.
* @param {string} text Text to count tokens in
* @returns {Promise<number>} Number of tokens in the text
*/
export async function countWebLlmTokens(text) {
if (!isWebLlmSupported()) {
throw new Error('WebLLM extension is not installed.');
}
const engine = SillyTavern.llm;
const response = await engine.countTokens(text);
return response;
}
/**
* Gets the size of the context in the WebLLM's default model.
* @returns {Promise<number>} Size of the context in the WebLLM model
*/
export async function getWebLlmContextSize() {
if (!isWebLlmSupported()) {
throw new Error('WebLLM extension is not installed.');
}
const engine = SillyTavern.llm;
await engine.loadModel();
const model = await engine.getCurrentModelInfo();
return model?.context_size;
}
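
Taken together, the three exports support a simple budget-then-generate flow. A sketch composed only of the helpers above; the wrapper name, budgeting policy, and numbers are illustrative, not part of this commit:

```js
// Illustrative wrapper: refuse to generate when prompt + response
// would not fit the loaded model's context window.
async function summarizeWithBudget(systemPrompt, text, maxResponse = 256) {
    const contextSize = await getWebLlmContextSize();
    const promptTokens = await countWebLlmTokens(`${systemPrompt}\n${text}`);

    if (promptTokens + maxResponse > contextSize) {
        throw new Error('Input too long for the loaded WebLLM model.');
    }

    const messages = [
        { role: 'system', content: systemPrompt },
        { role: 'user', content: text },
    ];
    return generateWebLlmChatPrompt(messages, { max_tokens: maxResponse });
}
```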

View File

@@ -31,6 +31,12 @@ import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js';
import { SlashCommand } from '../../slash-commands/SlashCommand.js';
import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js';
import { callGenericPopup, POPUP_RESULT, POPUP_TYPE } from '../../popup.js';
import { generateWebLlmChatPrompt, isWebLlmSupported } from '../shared.js';
/**
* @typedef {object} HashedMessage
* @property {string} text - The hashed message text
*/
const MODULE_NAME = 'vectors';
@@ -192,6 +198,11 @@ function splitByChunks(items) {
return chunkedItems;
}
/**
* Summarizes messages using the Extras API method.
* @param {HashedMessage[]} hashedMessages Array of hashed messages
* @returns {Promise<HashedMessage[]>} Summarized messages
*/
async function summarizeExtra(hashedMessages) {
for (const element of hashedMessages) {
try {
@@ -223,6 +234,11 @@ async function summarizeExtra(hashedMessages) {
return hashedMessages;
}
/**
* Summarizes messages using the main API method.
* @param {HashedMessage[]} hashedMessages Array of hashed messages
* @returns {Promise<HashedMessage[]>} Summarized messages
*/
async function summarizeMain(hashedMessages) {
for (const element of hashedMessages) {
element.text = await generateRaw(element.text, '', false, false, settings.summary_prompt);
@@ -231,12 +247,39 @@ async function summarizeMain(hashedMessages) {
return hashedMessages;
}
/**
* Summarizes messages using WebLLM.
* @param {HashedMessage[]} hashedMessages Array of hashed messages
* @returns {Promise<HashedMessage[]>} Summarized messages
*/
async function summarizeWebLLM(hashedMessages) {
if (!isWebLlmSupported()) {
console.warn('Vectors: WebLLM is not supported');
return hashedMessages;
}
for (const element of hashedMessages) {
const messages = [{ role: 'system', content: settings.summary_prompt }, { role: 'user', content: element.text }];
element.text = await generateWebLlmChatPrompt(messages);
}
return hashedMessages;
}
/**
* Summarizes messages using the chosen method.
* @param {HashedMessage[]} hashedMessages Array of hashed messages
* @param {string} endpoint Type of endpoint to use
* @returns {Promise<HashedMessage[]>} Summarized messages
*/
async function summarize(hashedMessages, endpoint = 'main') {
switch (endpoint) {
case 'main':
return await summarizeMain(hashedMessages);
case 'extras':
return await summarizeExtra(hashedMessages);
case 'webllm':
return await summarizeWebLLM(hashedMessages);
default:
console.error('Unsupported endpoint', endpoint);
}
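
For reference, the dispatcher is presumably driven by the `vectors_summary_source` select shown below; the `settings.summary_source` key name here is an assumption mirroring that element's id. Note also that the `default` branch only logs, so an unknown endpoint resolves to `undefined`:

```js
// Assumed call site; 'summary_source' mirrors <select id="vectors_summary_source">.
const summarized = await summarize(hashedMessages, settings.summary_source);

// Defensive fallback, since an unsupported endpoint resolves to undefined:
const safe = summarized ?? hashedMessages;
```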

View File

@@ -374,10 +374,11 @@
<select id="vectors_summary_source" class="text_pole">
<option value="main" data-i18n="Main API">Main API</option>
<option value="extras" data-i18n="Extras API">Extras API</option>
<option value="webllm" data-i18n="WebLLM Extension">WebLLM Extension</option>
</select>
<label for="vectors_summary_prompt" title="Summary Prompt:">Summary Prompt:</label>
<small data-i18n="Only used when Main API is selected.">Only used when Main API is selected.</small>
<small data-i18n="Only used when Main API or WebLLM Extension is selected.">Only used when Main API or WebLLM Extension is selected.</small>
<textarea id="vectors_summary_prompt" class="text_pole textarea_compact" rows="6" placeholder="This prompt will be sent to AI to request the summary generation."></textarea>
</div>
</div>

View File

@@ -0,0 +1,7 @@
<h3>Enter the Git URL of the extension to install</h3>
<br>
<p><b>Disclaimer:</b> Please be aware that using external extensions can have unintended side effects and may pose
security risks. Always make sure you trust the source before importing an extension. We are not responsible for any
damage caused by third-party extensions.</p>
<br>
<p>Example: <tt> https://github.com/author/extension-name </tt></p>