Merge pull request #2988 from QuantumEntangledAndy/feat/cachedVectorSummaries

Add client side cacheing of vector summaries
This commit is contained in:
Cohee 2024-10-15 23:32:20 +03:00 committed by GitHub
commit e01a243ce5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -36,6 +36,7 @@ import { generateWebLlmChatPrompt, isWebLlmSupported } from '../shared.js';
/** /**
* @typedef {object} HashedMessage * @typedef {object} HashedMessage
* @property {string} text - The hashed message text * @property {string} text - The hashed message text
* @property {number} hash - The hash used as the vector key
*/ */
const MODULE_NAME = 'vectors'; const MODULE_NAME = 'vectors';
@ -96,6 +97,8 @@ const settings = {
const moduleWorker = new ModuleWorkerWrapper(synchronizeChat); const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
const cachedSummaries = new Map();
/** /**
* Gets the Collection ID for a file embedded in the chat. * Gets the Collection ID for a file embedded in the chat.
* @param {string} fileUrl URL of the file * @param {string} fileUrl URL of the file
@ -118,6 +121,10 @@ async function onVectorizeAllClick() {
return; return;
} }
// Clear all cached summaries to ensure that new ones are created
// upon request of a full vectorise
cachedSummaries.clear();
const batchSize = 5; const batchSize = 5;
const elapsedLog = []; const elapsedLog = [];
let finished = false; let finished = false;
@ -200,11 +207,10 @@ function splitByChunks(items) {
/** /**
* Summarizes messages using the Extras API method. * Summarizes messages using the Extras API method.
* @param {HashedMessage[]} hashedMessages Array of hashed messages * @param {HashedMessage} element hashed message
* @returns {Promise<HashedMessage[]>} Summarized messages * @returns {Promise<boolean>} Sucess
*/ */
async function summarizeExtra(hashedMessages) { async function summarizeExtra(element) {
for (const element of hashedMessages) {
try { try {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/summarize'; url.pathname = '/api/summarize';
@ -228,42 +234,37 @@ async function summarizeExtra(hashedMessages) {
} }
catch (error) { catch (error) {
console.log(error); console.log(error);
} return false;
} }
return hashedMessages; return true;
} }
/** /**
* Summarizes messages using the main API method. * Summarizes messages using the main API method.
* @param {HashedMessage[]} hashedMessages Array of hashed messages * @param {HashedMessage} element hashed message
* @returns {Promise<HashedMessage[]>} Summarized messages * @returns {Promise<boolean>} Sucess
*/ */
async function summarizeMain(hashedMessages) { async function summarizeMain(element) {
for (const element of hashedMessages) {
element.text = await generateRaw(element.text, '', false, false, settings.summary_prompt); element.text = await generateRaw(element.text, '', false, false, settings.summary_prompt);
} return true;
return hashedMessages;
} }
/** /**
* Summarizes messages using WebLLM. * Summarizes messages using WebLLM.
* @param {HashedMessage[]} hashedMessages Array of hashed messages * @param {HashedMessage} element hashed message
* @returns {Promise<HashedMessage[]>} Summarized messages * @returns {Promise<boolean>} Sucess
*/ */
async function summarizeWebLLM(hashedMessages) { async function summarizeWebLLM(element) {
if (!isWebLlmSupported()) { if (!isWebLlmSupported()) {
console.warn('Vectors: WebLLM is not supported'); console.warn('Vectors: WebLLM is not supported');
return hashedMessages; return false;
} }
for (const element of hashedMessages) {
const messages = [{ role: 'system', content: settings.summary_prompt }, { role: 'user', content: element.text }]; const messages = [{ role: 'system', content: settings.summary_prompt }, { role: 'user', content: element.text }];
element.text = await generateWebLlmChatPrompt(messages); element.text = await generateWebLlmChatPrompt(messages);
}
return hashedMessages; return true;
} }
/** /**
@ -273,16 +274,35 @@ async function summarizeWebLLM(hashedMessages) {
* @returns {Promise<HashedMessage[]>} Summarized messages * @returns {Promise<HashedMessage[]>} Summarized messages
*/ */
async function summarize(hashedMessages, endpoint = 'main') { async function summarize(hashedMessages, endpoint = 'main') {
for (const element of hashedMessages) {
const cachedSummary = cachedSummaries.get(element.hash);
if (!cachedSummary) {
let success = true;
switch (endpoint) { switch (endpoint) {
case 'main': case 'main':
return await summarizeMain(hashedMessages); success = await summarizeMain(element);
break;
case 'extras': case 'extras':
return await summarizeExtra(hashedMessages); success = await summarizeExtra(element);
break;
case 'webllm': case 'webllm':
return await summarizeWebLLM(hashedMessages); success = await summarizeWebLLM(element);
break;
default: default:
console.error('Unsupported endpoint', endpoint); console.error('Unsupported endpoint', endpoint);
success = false;
break;
} }
if (success) {
cachedSummaries.set(element.hash, element.text);
} else {
break;
}
} else {
element.text = cachedSummary;
}
}
return hashedMessages;
} }
async function synchronizeChat(batchSize = 5) { async function synchronizeChat(batchSize = 5) {
@ -307,16 +327,15 @@ async function synchronizeChat(batchSize = 5) {
return -1; return -1;
} }
let hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(substituteParams(x.mes)), hash: getStringHash(substituteParams(x.mes)), index: context.chat.indexOf(x) })); const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(substituteParams(x.mes)), hash: getStringHash(substituteParams(x.mes)), index: context.chat.indexOf(x) }));
const hashesInCollection = await getSavedHashes(chatId); const hashesInCollection = await getSavedHashes(chatId);
if (settings.summarize) { let newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash));
hashedMessages = await summarize(hashedMessages, settings.summary_source);
}
const newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash));
const deletedHashes = hashesInCollection.filter(x => !hashedMessages.some(y => y.hash === x)); const deletedHashes = hashesInCollection.filter(x => !hashedMessages.some(y => y.hash === x));
if (settings.summarize) {
newVectorItems = await summarize(newVectorItems, settings.summary_source);
}
if (newVectorItems.length > 0) { if (newVectorItems.length > 0) {
const chunkedBatch = splitByChunks(newVectorItems.slice(0, batchSize)); const chunkedBatch = splitByChunks(newVectorItems.slice(0, batchSize));
@ -687,25 +706,17 @@ const onChatEvent = debounce(async () => await moduleWorker.update(), debounce_t
* @returns {Promise<string>} Text to query * @returns {Promise<string>} Text to query
*/ */
async function getQueryText(chat, initiator) { async function getQueryText(chat, initiator) {
let queryText = ''; let hashedMessages = chat
let i = 0; .map(x => ({ text: String(substituteParams(x.mes)), hash: getStringHash(substituteParams(x.mes)) }))
.filter(x => x.text)
let hashedMessages = chat.map(x => ({ text: String(substituteParams(x.mes)) })); .reverse()
.slice(0, settings.query);
if (initiator === 'chat' && settings.enabled_chats && settings.summarize && settings.summarize_sent) { if (initiator === 'chat' && settings.enabled_chats && settings.summarize && settings.summarize_sent) {
hashedMessages = await summarize(hashedMessages, settings.summary_source); hashedMessages = await summarize(hashedMessages, settings.summary_source);
} }
for (const message of hashedMessages.slice().reverse()) { const queryText = hashedMessages.map(x => x.text).join('\n');
if (message.text) {
queryText += message.text + '\n';
i++;
}
if (i === settings.query) {
break;
}
}
return collapseNewlines(queryText).trim(); return collapseNewlines(queryText).trim();
} }