diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js index c78004917..eb989556e 100644 --- a/public/scripts/extensions/vectors/index.js +++ b/public/scripts/extensions/vectors/index.js @@ -36,6 +36,7 @@ import { generateWebLlmChatPrompt, isWebLlmSupported } from '../shared.js'; /** * @typedef {object} HashedMessage * @property {string} text - The hashed message text + * @property {number} hash - The hash used as the vector key */ const MODULE_NAME = 'vectors'; @@ -96,6 +97,8 @@ const settings = { const moduleWorker = new ModuleWorkerWrapper(synchronizeChat); +const cachedSummaries = new Map(); + /** * Gets the Collection ID for a file embedded in the chat. * @param {string} fileUrl URL of the file @@ -118,6 +121,10 @@ async function onVectorizeAllClick() { return; } + // Clear all cached summaries to ensure that new ones are created + // upon request of a full vectorise + cachedSummaries.clear(); + const batchSize = 5; const elapsedLog = []; let finished = false; @@ -200,70 +207,64 @@ function splitByChunks(items) { /** * Summarizes messages using the Extras API method. - * @param {HashedMessage[]} hashedMessages Array of hashed messages - * @returns {Promise} Summarized messages + * @param {HashedMessage} element hashed message + * @returns {Promise} Sucess */ -async function summarizeExtra(hashedMessages) { - for (const element of hashedMessages) { - try { - const url = new URL(getApiUrl()); - url.pathname = '/api/summarize'; +async function summarizeExtra(element) { + try { + const url = new URL(getApiUrl()); + url.pathname = '/api/summarize'; - const apiResult = await doExtrasFetch(url, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Bypass-Tunnel-Reminder': 'bypass', - }, - body: JSON.stringify({ - text: element.text, - params: {}, - }), - }); + const apiResult = await doExtrasFetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Bypass-Tunnel-Reminder': 'bypass', + }, + body: JSON.stringify({ + text: element.text, + params: {}, + }), + }); - if (apiResult.ok) { - const data = await apiResult.json(); - element.text = data.summary; - } - } - catch (error) { - console.log(error); + if (apiResult.ok) { + const data = await apiResult.json(); + element.text = data.summary; } } + catch (error) { + console.log(error); + return false; + } - return hashedMessages; + return true; } /** * Summarizes messages using the main API method. - * @param {HashedMessage[]} hashedMessages Array of hashed messages - * @returns {Promise} Summarized messages + * @param {HashedMessage} element hashed message + * @returns {Promise} Sucess */ -async function summarizeMain(hashedMessages) { - for (const element of hashedMessages) { - element.text = await generateRaw(element.text, '', false, false, settings.summary_prompt); - } - - return hashedMessages; +async function summarizeMain(element) { + element.text = await generateRaw(element.text, '', false, false, settings.summary_prompt); + return true; } /** * Summarizes messages using WebLLM. - * @param {HashedMessage[]} hashedMessages Array of hashed messages - * @returns {Promise} Summarized messages + * @param {HashedMessage} element hashed message + * @returns {Promise} Sucess */ -async function summarizeWebLLM(hashedMessages) { +async function summarizeWebLLM(element) { if (!isWebLlmSupported()) { console.warn('Vectors: WebLLM is not supported'); - return hashedMessages; + return false; } - for (const element of hashedMessages) { - const messages = [{ role:'system', content: settings.summary_prompt }, { role:'user', content: element.text }]; - element.text = await generateWebLlmChatPrompt(messages); - } + const messages = [{ role: 'system', content: settings.summary_prompt }, { role: 'user', content: element.text }]; + element.text = await generateWebLlmChatPrompt(messages); - return hashedMessages; + return true; } /** @@ -273,16 +274,35 @@ async function summarizeWebLLM(hashedMessages) { * @returns {Promise} Summarized messages */ async function summarize(hashedMessages, endpoint = 'main') { - switch (endpoint) { - case 'main': - return await summarizeMain(hashedMessages); - case 'extras': - return await summarizeExtra(hashedMessages); - case 'webllm': - return await summarizeWebLLM(hashedMessages); - default: - console.error('Unsupported endpoint', endpoint); + for (const element of hashedMessages) { + const cachedSummary = cachedSummaries.get(element.hash); + if (!cachedSummary) { + let success = true; + switch (endpoint) { + case 'main': + success = await summarizeMain(element); + break; + case 'extras': + success = await summarizeExtra(element); + break; + case 'webllm': + success = await summarizeWebLLM(element); + break; + default: + console.error('Unsupported endpoint', endpoint); + success = false; + break; + } + if (success) { + cachedSummaries.set(element.hash, element.text); + } else { + break; + } + } else { + element.text = cachedSummary; + } } + return hashedMessages; } async function synchronizeChat(batchSize = 5) { @@ -307,16 +327,15 @@ async function synchronizeChat(batchSize = 5) { return -1; } - let hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(substituteParams(x.mes)), hash: getStringHash(substituteParams(x.mes)), index: context.chat.indexOf(x) })); + const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(substituteParams(x.mes)), hash: getStringHash(substituteParams(x.mes)), index: context.chat.indexOf(x) })); const hashesInCollection = await getSavedHashes(chatId); - if (settings.summarize) { - hashedMessages = await summarize(hashedMessages, settings.summary_source); - } - - const newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash)); + let newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash)); const deletedHashes = hashesInCollection.filter(x => !hashedMessages.some(y => y.hash === x)); + if (settings.summarize) { + newVectorItems = await summarize(newVectorItems, settings.summary_source); + } if (newVectorItems.length > 0) { const chunkedBatch = splitByChunks(newVectorItems.slice(0, batchSize)); @@ -687,25 +706,17 @@ const onChatEvent = debounce(async () => await moduleWorker.update(), debounce_t * @returns {Promise} Text to query */ async function getQueryText(chat, initiator) { - let queryText = ''; - let i = 0; - - let hashedMessages = chat.map(x => ({ text: String(substituteParams(x.mes)) })); + let hashedMessages = chat + .map(x => ({ text: String(substituteParams(x.mes)), hash: getStringHash(substituteParams(x.mes)) })) + .filter(x => x.text) + .reverse() + .slice(0, settings.query); if (initiator === 'chat' && settings.enabled_chats && settings.summarize && settings.summarize_sent) { hashedMessages = await summarize(hashedMessages, settings.summary_source); } - for (const message of hashedMessages.slice().reverse()) { - if (message.text) { - queryText += message.text + '\n'; - i++; - } - - if (i === settings.query) { - break; - } - } + const queryText = hashedMessages.map(x => x.text).join('\n'); return collapseNewlines(queryText).trim(); }