Mirror of https://github.com/SillyTavern/SillyTavern.git
Synced 2024-12-12 09:26:33 +01:00

Add WebLLM extension summarization

parent 77ab694ea0
commit 8685c2f471
@@ -25,6 +25,7 @@ import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js';
 import { SlashCommand } from '../../slash-commands/SlashCommand.js';
 import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js';
 import { MacrosParser } from '../../macros.js';
+import { countWebLlmTokens, generateWebLlmChatPrompt, getWebLlmContextSize, isWebLlmSupported } from '../shared.js';
 export { MODULE_NAME };
 
 const MODULE_NAME = '1_memory';
@@ -36,6 +37,40 @@ let lastMessageHash = null;
 let lastMessageId = null;
 let inApiCall = false;
 
+/**
+ * Count the number of tokens in the provided text.
+ * @param {string} text Text to count tokens for
+ * @returns {Promise<number>} Number of tokens in the text
+ */
+async function countSourceTokens(text, padding = 0) {
+    if (extension_settings.memory.source === summary_sources.webllm) {
+        const count = await countWebLlmTokens(text);
+        return count + padding;
+    }
+
+    if (extension_settings.memory.source === summary_sources.extras) {
+        const count = getTextTokens(tokenizers.GPT2, text).length;
+        return count + padding;
+    }
+
+    return await getTokenCountAsync(text, padding);
+}
+
+async function getSourceContextSize() {
+    const overrideLength = extension_settings.memory.overrideResponseLength;
+
+    if (extension_settings.memory.source === summary_sources.webllm) {
+        const maxContext = await getWebLlmContextSize();
+        return overrideLength > 0 ? (maxContext - overrideLength) : Math.round(maxContext * 0.75);
+    }
+
+    if (extension_settings.memory.source === summary_sources.extras) {
+        return 1024;
+    }
+
+    return getMaxContextSize(overrideLength);
+}
+
 const formatMemoryValue = function (value) {
     if (!value) {
         return '';
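The two helpers added above route token counting through whichever summary source is active. A minimal standalone sketch of that dispatch, with hypothetical stub backends standing in for the real countWebLlmTokens / getTextTokens / getTokenCountAsync calls:

// Minimal sketch of the per-source dispatch that countSourceTokens follows.
// The backends are hypothetical stand-ins, not the real tokenizer APIs.
const sources = { extras: 'extras', main: 'main', webllm: 'webllm' };

const backends = {
    webllm: async (text) => Math.ceil(text.length / 4), // placeholder heuristic
    extras: async (text) => text.split(/\s+/).filter(Boolean).length,
    main: async (text) => Math.ceil(text.length / 4),
};

async function countTokensSketch(source, text, padding = 0) {
    const counter = backends[source] ?? backends.main; // fall back to the main API
    return await counter(text) + padding;
}

// Usage: the padding keeps prompt-size checks conservative.
countTokensSketch(sources.webllm, 'Hello there, WebLLM', 64).then(console.log);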
@@ -55,6 +90,7 @@ const saveChatDebounced = debounce(() => getContext().saveChat(), debounce_timeo
 const summary_sources = {
     'extras': 'extras',
     'main': 'main',
+    'webllm': 'webllm',
 };
 
 const prompt_builders = {
@@ -130,12 +166,12 @@ function loadSettings() {
 
 async function onPromptForceWordsAutoClick() {
     const context = getContext();
-    const maxPromptLength = getMaxContextSize(extension_settings.memory.overrideResponseLength);
+    const maxPromptLength = await getSourceContextSize();
     const chat = context.chat;
     const allMessages = chat.filter(m => !m.is_system && m.mes).map(m => m.mes);
     const messagesWordCount = allMessages.map(m => extractAllWords(m)).flat().length;
     const averageMessageWordCount = messagesWordCount / allMessages.length;
-    const tokensPerWord = await getTokenCountAsync(allMessages.join('\n')) / messagesWordCount;
+    const tokensPerWord = await countSourceTokens(allMessages.join('\n')) / messagesWordCount;
     const wordsPerToken = 1 / tokensPerWord;
     const maxPromptLengthWords = Math.round(maxPromptLength * wordsPerToken);
     // How many words should pass so that messages will start be dropped out of context;
@@ -168,15 +204,15 @@ async function onPromptForceWordsAutoClick() {
 
 async function onPromptIntervalAutoClick() {
     const context = getContext();
-    const maxPromptLength = getMaxContextSize(extension_settings.memory.overrideResponseLength);
+    const maxPromptLength = await getSourceContextSize();
     const chat = context.chat;
     const allMessages = chat.filter(m => !m.is_system && m.mes).map(m => m.mes);
     const messagesWordCount = allMessages.map(m => extractAllWords(m)).flat().length;
-    const messagesTokenCount = await getTokenCountAsync(allMessages.join('\n'));
+    const messagesTokenCount = await countSourceTokens(allMessages.join('\n'));
     const tokensPerWord = messagesTokenCount / messagesWordCount;
     const averageMessageTokenCount = messagesTokenCount / allMessages.length;
     const targetSummaryTokens = Math.round(extension_settings.memory.promptWords * tokensPerWord);
-    const promptTokens = await getTokenCountAsync(extension_settings.memory.prompt);
+    const promptTokens = await countSourceTokens(extension_settings.memory.prompt);
     const promptAllowance = maxPromptLength - promptTokens - targetSummaryTokens;
     const maxMessagesPerSummary = extension_settings.memory.maxMessagesPerRequest || 0;
     const averageMessagesPerPrompt = Math.floor(promptAllowance / averageMessageTokenCount);
@@ -213,8 +249,8 @@ function onSummarySourceChange(event) {
 
 function switchSourceControls(value) {
     $('#memory_settings [data-summary-source]').each((_, element) => {
-        const source = $(element).data('summary-source');
-        $(element).toggle(source === value);
+        const source = element.dataset.summarySource.split(',').map(s => s.trim());
+        $(element).toggle(source.includes(value));
     });
 }
 
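switchSourceControls now treats data-summary-source as a comma-separated list, so one control can stay visible for several sources at once. A small browser-side sketch of the matching logic (sourcesOf is an illustrative name, not part of the commit):

// Sketch: matching one control against a comma-separated data-summary-source list.
// Assumes a browser DOM; dataset maps data-summary-source to summarySource.
function sourcesOf(element) {
    return (element.dataset.summarySource ?? '').split(',').map(s => s.trim());
}

const control = document.createElement('div');
control.dataset.summarySource = 'main, webllm';
console.log(sourcesOf(control).includes('webllm')); // true; trimming handles the space
console.log(sourcesOf(control).includes('extras')); // false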
@@ -359,6 +395,12 @@ async function onChatEvent() {
         }
     }
 
+    if (extension_settings.memory.source === summary_sources.webllm) {
+        if (!isWebLlmSupported()) {
+            return;
+        }
+    }
+
     const context = getContext();
     const chat = context.chat;
 
@@ -431,8 +473,12 @@ async function forceSummarizeChat() {
         return '';
     }
 
-    toastr.info('Summarizing chat...', 'Please wait');
-    const value = await summarizeChatMain(context, true, skipWIAN);
+    const toast = toastr.info('Summarizing chat...', 'Please wait', { timeOut: 0, extendedTimeOut: 0 });
+    const value = extension_settings.memory.source === summary_sources.main
+        ? await summarizeChatMain(context, true, skipWIAN)
+        : await summarizeChatWebLLM(context, true);
+
+    toastr.clear(toast);
 
     if (!value) {
         toastr.warning('Failed to summarize chat');
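The reworked forceSummarizeChat keeps the "Please wait" toast on screen for the whole request (timeOut and extendedTimeOut of 0) and dismisses it explicitly. The same pattern as a standalone helper; withProgressToast is illustrative, not part of the commit:

// Sketch of the persistent-toast pattern: toastr.info returns the toast element,
// which toastr.clear accepts to dismiss exactly that toast.
async function withProgressToast(message, title, task) {
    const toast = toastr.info(message, title, { timeOut: 0, extendedTimeOut: 0 });
    try {
        return await task();
    } finally {
        toastr.clear(toast); // dismiss only this toast once the task settles
    }
}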
@@ -484,16 +530,25 @@ async function summarizeChat(context) {
         case summary_sources.main:
             await summarizeChatMain(context, false, skipWIAN);
             break;
+        case summary_sources.webllm:
+            await summarizeChatWebLLM(context, false);
+            break;
         default:
             break;
     }
 }
 
-async function summarizeChatMain(context, force, skipWIAN) {
+/**
+ * Check if the chat should be summarized based on the current conditions.
+ * Return summary prompt if it should be summarized.
+ * @param {any} context ST context
+ * @param {boolean} force Summarize the chat regardless of the conditions
+ * @returns {Promise<string>} Summary prompt or empty string
+ */
+async function getSummaryPromptForNow(context, force) {
     if (extension_settings.memory.promptInterval === 0 && !force) {
         console.debug('Prompt interval is set to 0, skipping summarization');
-        return;
+        return '';
     }
 
     try {
@@ -505,17 +560,17 @@ async function summarizeChatMain(context, force, skipWIAN) {
         waitUntilCondition(() => is_send_press === false, 30000, 100);
     } catch {
         console.debug('Timeout waiting for is_send_press');
-        return;
+        return '';
     }
 
     if (!context.chat.length) {
         console.debug('No messages in chat to summarize');
-        return;
+        return '';
     }
 
     if (context.chat.length < extension_settings.memory.promptInterval && !force) {
         console.debug(`Not enough messages in chat to summarize (chat: ${context.chat.length}, interval: ${extension_settings.memory.promptInterval})`);
-        return;
+        return '';
     }
 
     let messagesSinceLastSummary = 0;
@@ -539,7 +594,7 @@ async function summarizeChatMain(context, force, skipWIAN) {
 
     if (!conditionSatisfied && !force) {
         console.debug(`Summary conditions not satisfied (messages: ${messagesSinceLastSummary}, interval: ${extension_settings.memory.promptInterval}, words: ${wordsSinceLastSummary}, force words: ${extension_settings.memory.promptForceWords})`);
-        return;
+        return '';
     }
 
     console.log('Summarizing chat, messages since last summary: ' + messagesSinceLastSummary, 'words since last summary: ' + wordsSinceLastSummary);
@@ -547,6 +602,63 @@ async function summarizeChatMain(context, force, skipWIAN) {
 
     if (!prompt) {
         console.debug('Summarization prompt is empty. Skipping summarization.');
         return '';
     }
 
+    return prompt;
+}
+
+async function summarizeChatWebLLM(context, force) {
+    if (!isWebLlmSupported()) {
+        return;
+    }
+
+    const prompt = await getSummaryPromptForNow(context, force);
+
+    if (!prompt) {
+        return;
+    }
+
+    const { rawPrompt, lastUsedIndex } = await getRawSummaryPrompt(context, prompt);
+
+    if (lastUsedIndex === null || lastUsedIndex === -1) {
+        if (force) {
+            toastr.info('To try again, remove the latest summary.', 'No messages found to summarize');
+        }
+
+        return null;
+    }
+
+    const messages = [
+        { role: 'system', content: prompt },
+        { role: 'user', content: rawPrompt },
+    ];
+
+    const params = {};
+
+    if (extension_settings.memory.overrideResponseLength > 0) {
+        params.max_tokens = extension_settings.memory.overrideResponseLength;
+    }
+
+    const summary = await generateWebLlmChatPrompt(messages, params);
+    const newContext = getContext();
+
+    // something changed during summarization request
+    if (newContext.groupId !== context.groupId ||
+        newContext.chatId !== context.chatId ||
+        (!newContext.groupId && (newContext.characterId !== context.characterId))) {
+        console.log('Context changed, summary discarded');
+        return;
+    }
+
+    setMemoryContext(summary, true, lastUsedIndex);
+    return summary;
+}
+
+async function summarizeChatMain(context, force, skipWIAN) {
+    const prompt = await getSummaryPromptForNow(context, force);
+
+    if (!prompt) {
+        return;
+    }
+
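summarizeChatWebLLM sends a two-message chat prompt (system carries the summary instruction, user carries the raw chat text) and forwards max_tokens only when a response-length override is set. A sketch of that request shape, with generate() as a hypothetical stand-in for generateWebLlmChatPrompt:

// Sketch of the request the WebLLM path builds; generate() is a stand-in
// for generateWebLlmChatPrompt from ../shared.js.
async function buildAndRunSummary(generate, summaryPrompt, rawChatText, overrideResponseLength = 0) {
    const messages = [
        { role: 'system', content: summaryPrompt }, // summarization instruction
        { role: 'user', content: rawChatText },     // chat transcript to condense
    ];

    const params = {};
    if (overrideResponseLength > 0) {
        params.max_tokens = overrideResponseLength; // only cap the reply when overridden
    }

    return generate(messages, params);
}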
@@ -634,7 +746,7 @@ async function getRawSummaryPrompt(context, prompt) {
     chat.pop(); // We always exclude the last message from the buffer
     const chatBuffer = [];
     const PADDING = 64;
-    const PROMPT_SIZE = getMaxContextSize(extension_settings.memory.overrideResponseLength);
+    const PROMPT_SIZE = await getSourceContextSize();
     let latestUsedMessage = null;
 
     for (let index = latestSummaryIndex + 1; index < chat.length; index++) {
@@ -651,7 +763,7 @@ async function getRawSummaryPrompt(context, prompt) {
         const entry = `${message.name}:\n${message.mes}`;
         chatBuffer.push(entry);
 
-        const tokens = await getTokenCountAsync(getMemoryString(true), PADDING);
+        const tokens = await countSourceTokens(getMemoryString(true), PADDING);
 
         if (tokens > PROMPT_SIZE) {
             chatBuffer.pop();
@@ -13,6 +13,7 @@
             <select id="summary_source">
                 <option value="main" data-i18n="ext_sum_main_api">Main API</option>
                 <option value="extras">Extras API</option>
+                <option value="webllm" data-i18n="ext_sum_webllm">WebLLM Extension</option>
             </select><br>
 
             <div class="flex-container justifyspacebetween alignitemscenter">
@@ -24,7 +25,7 @@
 
             <textarea id="memory_contents" class="text_pole textarea_compact" rows="6" data-i18n="[placeholder]ext_sum_memory_placeholder" placeholder="Summary will be generated here..."></textarea>
             <div class="memory_contents_controls">
-                <div id="memory_force_summarize" data-summary-source="main" class="menu_button menu_button_icon" title="Trigger a summary update right now." data-i18n="[title]ext_sum_force_tip">
+                <div id="memory_force_summarize" data-summary-source="main,webllm" class="menu_button menu_button_icon" title="Trigger a summary update right now." data-i18n="[title]ext_sum_force_tip">
                     <i class="fa-solid fa-database"></i>
                     <span data-i18n="ext_sum_force_text">Summarize now</span>
                 </div>
@@ -58,7 +59,7 @@
                     <span data-i18n="ext_sum_prompt_builder_3">Classic, blocking</span>
                 </label>
             </div>
-            <div data-summary-source="main">
+            <div data-summary-source="main,webllm">
                 <label for="memory_prompt" class="title_restorable">
                     <span data-i18n="Summary Prompt">Summary Prompt</span>
                     <div id="memory_prompt_restore" data-i18n="[title]ext_sum_restore_default_prompt_tip" title="Restore default prompt" class="right_menu_button">
@@ -74,7 +75,7 @@
             </label>
             <input id="memory_override_response_length" type="range" value="{{defaultSettings.overrideResponseLength}}" min="{{defaultSettings.overrideResponseLengthMin}}" max="{{defaultSettings.overrideResponseLengthMax}}" step="{{defaultSettings.overrideResponseLengthStep}}" />
             <label for="memory_max_messages_per_request">
-                <span data-i18n="ext_sum_raw_max_msg">[Raw] Max messages per request</span> (<span id="memory_max_messages_per_request_value"></span>)
+                <span data-i18n="ext_sum_raw_max_msg">[Raw/WebLLM] Max messages per request</span> (<span id="memory_max_messages_per_request_value"></span>)
                 <small class="memory_disabled_hint" data-i18n="ext_sum_0_unlimited">0 = unlimited</small>
             </label>
             <input id="memory_max_messages_per_request" type="range" value="{{defaultSettings.maxMessagesPerRequest}}" min="{{defaultSettings.maxMessagesPerRequestMin}}" max="{{defaultSettings.maxMessagesPerRequestMax}}" step="{{defaultSettings.maxMessagesPerRequestStep}}" />
@@ -183,32 +183,40 @@ function throwIfInvalidModel(useReverseProxy) {
  */
 export function isWebLlmSupported() {
     if (!('gpu' in navigator)) {
-        toastr.error('Your browser does not support the WebGPU API. Please use a different browser.', 'WebLLM', {
-            preventDuplicates: true,
-            timeOut: 0,
-            extendedTimeOut: 0,
-        });
+        const warningKey = 'webllm_browser_warning_shown';
+        if (!sessionStorage.getItem(warningKey)) {
+            toastr.error('Your browser does not support the WebGPU API. Please use a different browser.', 'WebLLM', {
+                preventDuplicates: true,
+                timeOut: 0,
+                extendedTimeOut: 0,
+            });
+            sessionStorage.setItem(warningKey, '1');
+        }
         return false;
     }
 
     if (!('llm' in SillyTavern)) {
-        toastr.error('WebLLM extension is not installed. Click here to install it.', 'WebLLM', {
-            timeOut: 0,
-            extendedTimeOut: 0,
-            preventDuplicates: true,
-            onclick: () => {
-                const button = document.getElementById('third_party_extension_button');
-                if (button) {
-                    button.click();
-                }
-
-                const input = document.querySelector('dialog textarea');
-
-                if (input instanceof HTMLTextAreaElement) {
-                    input.value = 'https://github.com/SillyTavern/Extension-WebLLM';
-                }
-            },
-        });
+        const warningKey = 'webllm_extension_warning_shown';
+        if (!sessionStorage.getItem(warningKey)) {
+            toastr.error('WebLLM extension is not installed. Click here to install it.', 'WebLLM', {
+                timeOut: 0,
+                extendedTimeOut: 0,
+                preventDuplicates: true,
+                onclick: () => {
+                    const button = document.getElementById('third_party_extension_button');
+                    if (button) {
+                        button.click();
+                    }
+
+                    const input = document.querySelector('dialog textarea');
+
+                    if (input instanceof HTMLTextAreaElement) {
+                        input.value = 'https://github.com/SillyTavern/Extension-WebLLM';
+                    }
+                },
+            });
+            sessionStorage.setItem(warningKey, '1');
+        }
         return false;
     }
 
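The sessionStorage guard added here means each warning fires at most once per browser session instead of on every isWebLlmSupported call. The pattern in isolation (warnOncePerSession is a hypothetical name for illustration):

// Sketch: surface a warning at most once per browser session.
function warnOncePerSession(key, showWarning) {
    if (!sessionStorage.getItem(key)) {
        showWarning();
        sessionStorage.setItem(key, '1'); // cleared automatically when the tab closes
    }
}

// Usage mirroring the diff (key and message text from the commit):
warnOncePerSession('webllm_browser_warning_shown', () => {
    console.warn('Your browser does not support the WebGPU API.');
});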
@@ -218,15 +226,16 @@ export function isWebLlmSupported() {
 /**
  * Generates text in response to a chat prompt using WebLLM.
  * @param {any[]} messages Messages to use for generating
+ * @param {object} params Additional parameters
  * @returns {Promise<string>} Generated response
  */
-export async function generateWebLlmChatPrompt(messages) {
+export async function generateWebLlmChatPrompt(messages, params = {}) {
     if (!isWebLlmSupported()) {
         throw new Error('WebLLM extension is not installed.');
     }
 
     const engine = SillyTavern.llm;
-    const response = await engine.generateChatPrompt(messages);
+    const response = await engine.generateChatPrompt(messages, params);
     return response;
 }
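Since params defaults to an empty object, existing generateWebLlmChatPrompt call sites keep working unchanged. A usage sketch, assuming the WebLLM extension is installed so the call does not throw:

// Usage sketch for the extended signature; the second argument is optional.
async function exampleSummary() {
    const messages = [
        { role: 'system', content: 'Summarize the conversation.' },
        { role: 'user', content: 'Alice: hi\nBob: hello' },
    ];

    const unbounded = await generateWebLlmChatPrompt(messages); // old-style call
    const capped = await generateWebLlmChatPrompt(messages, { max_tokens: 256 }); // capped reply
    return { unbounded, capped };
}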