diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js index c78004917..c375dc0ef 100644 --- a/public/scripts/extensions/vectors/index.js +++ b/public/scripts/extensions/vectors/index.js @@ -36,6 +36,7 @@ import { generateWebLlmChatPrompt, isWebLlmSupported } from '../shared.js'; /** * @typedef {object} HashedMessage * @property {string} text - The hashed message text + * @property {number} hash - The hash used as the vector key */ const MODULE_NAME = 'vectors'; @@ -96,6 +97,8 @@ const settings = { const moduleWorker = new ModuleWorkerWrapper(synchronizeChat); +const cachedSummaries = new Map(); + /** * Gets the Collection ID for a file embedded in the chat. * @param {string} fileUrl URL of the file @@ -118,6 +121,10 @@ async function onVectorizeAllClick() { return; } + // Clear all cached summaries to ensure that new ones are created + // upon request of a full vectorise + cachedSummaries.clear(); + const batchSize = 5; const elapsedLog = []; let finished = false; @@ -200,70 +207,64 @@ function splitByChunks(items) { /** * Summarizes messages using the Extras API method. - * @param {HashedMessage[]} hashedMessages Array of hashed messages - * @returns {Promise} Summarized messages + * @param {HashedMessage} element hashed message + * @returns {Promise} Sucess */ -async function summarizeExtra(hashedMessages) { - for (const element of hashedMessages) { - try { - const url = new URL(getApiUrl()); - url.pathname = '/api/summarize'; +async function summarizeExtra(element) { + try { + const url = new URL(getApiUrl()); + url.pathname = '/api/summarize'; - const apiResult = await doExtrasFetch(url, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Bypass-Tunnel-Reminder': 'bypass', - }, - body: JSON.stringify({ - text: element.text, - params: {}, - }), - }); + const apiResult = await doExtrasFetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Bypass-Tunnel-Reminder': 'bypass', + }, + body: JSON.stringify({ + text: element.text, + params: {}, + }), + }); - if (apiResult.ok) { - const data = await apiResult.json(); - element.text = data.summary; - } - } - catch (error) { - console.log(error); + if (apiResult.ok) { + const data = await apiResult.json(); + element.text = data.summary; } } + catch (error) { + console.log(error); + return false; + } - return hashedMessages; + return true; } /** * Summarizes messages using the main API method. - * @param {HashedMessage[]} hashedMessages Array of hashed messages - * @returns {Promise} Summarized messages + * @param {HashedMessage} element hashed message + * @returns {Promise} Sucess */ -async function summarizeMain(hashedMessages) { - for (const element of hashedMessages) { - element.text = await generateRaw(element.text, '', false, false, settings.summary_prompt); - } - - return hashedMessages; +async function summarizeMain(element) { + element.text = await generateRaw(element.text, '', false, false, settings.summary_prompt); + return true; } /** * Summarizes messages using WebLLM. - * @param {HashedMessage[]} hashedMessages Array of hashed messages - * @returns {Promise} Summarized messages + * @param {HashedMessage} element hashed message + * @returns {Promise} Sucess */ -async function summarizeWebLLM(hashedMessages) { +async function summarizeWebLLM(element) { if (!isWebLlmSupported()) { console.warn('Vectors: WebLLM is not supported'); - return hashedMessages; + return false; } - for (const element of hashedMessages) { - const messages = [{ role:'system', content: settings.summary_prompt }, { role:'user', content: element.text }]; - element.text = await generateWebLlmChatPrompt(messages); - } + const messages = [{ role:'system', content: settings.summary_prompt }, { role:'user', content: element.text }]; + element.text = await generateWebLlmChatPrompt(messages); - return hashedMessages; + return true; } /** @@ -273,16 +274,35 @@ async function summarizeWebLLM(hashedMessages) { * @returns {Promise} Summarized messages */ async function summarize(hashedMessages, endpoint = 'main') { - switch (endpoint) { - case 'main': - return await summarizeMain(hashedMessages); - case 'extras': - return await summarizeExtra(hashedMessages); - case 'webllm': - return await summarizeWebLLM(hashedMessages); - default: - console.error('Unsupported endpoint', endpoint); + for (const element of hashedMessages) { + const cachedSummary = cachedSummaries.get(element.hash) + if (!cachedSummary) { + let sucess = true; + switch (endpoint) { + case 'main': + sucess = await summarizeMain(element); + break; + case 'extras': + sucess = await summarizeExtra(element); + break; + case 'webllm': + sucess = await summarizeWebLLM(element); + break; + default: + console.error('Unsupported endpoint', endpoint); + sucess = false; + break; + } + if (sucess) { + cachedSummaries.set(element.hash, element.text); + } else { + break; + } + } else { + element.text = cachedSummary; + } } + return hashedMessages; } async function synchronizeChat(batchSize = 5) { @@ -307,16 +327,15 @@ async function synchronizeChat(batchSize = 5) { return -1; } - let hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(substituteParams(x.mes)), hash: getStringHash(substituteParams(x.mes)), index: context.chat.indexOf(x) })); + const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(substituteParams(x.mes)), hash: getStringHash(substituteParams(x.mes)), index: context.chat.indexOf(x) })); const hashesInCollection = await getSavedHashes(chatId); - if (settings.summarize) { - hashedMessages = await summarize(hashedMessages, settings.summary_source); - } - - const newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash)); + let newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash)); const deletedHashes = hashesInCollection.filter(x => !hashedMessages.some(y => y.hash === x)); + if (settings.summarize) { + newVectorItems = await summarize(newVectorItems, settings.summary_source); + } if (newVectorItems.length > 0) { const chunkedBatch = splitByChunks(newVectorItems.slice(0, batchSize));