mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-03-03 03:17:54 +01:00
Merge pull request #2988 from QuantumEntangledAndy/feat/cachedVectorSummaries
Add client side cacheing of vector summaries
This commit is contained in:
commit
e01a243ce5
@ -36,6 +36,7 @@ import { generateWebLlmChatPrompt, isWebLlmSupported } from '../shared.js';
|
|||||||
/**
|
/**
|
||||||
* @typedef {object} HashedMessage
|
* @typedef {object} HashedMessage
|
||||||
* @property {string} text - The hashed message text
|
* @property {string} text - The hashed message text
|
||||||
|
* @property {number} hash - The hash used as the vector key
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const MODULE_NAME = 'vectors';
|
const MODULE_NAME = 'vectors';
|
||||||
@ -96,6 +97,8 @@ const settings = {
|
|||||||
|
|
||||||
const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
|
const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
|
||||||
|
|
||||||
|
const cachedSummaries = new Map();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the Collection ID for a file embedded in the chat.
|
* Gets the Collection ID for a file embedded in the chat.
|
||||||
* @param {string} fileUrl URL of the file
|
* @param {string} fileUrl URL of the file
|
||||||
@ -118,6 +121,10 @@ async function onVectorizeAllClick() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear all cached summaries to ensure that new ones are created
|
||||||
|
// upon request of a full vectorise
|
||||||
|
cachedSummaries.clear();
|
||||||
|
|
||||||
const batchSize = 5;
|
const batchSize = 5;
|
||||||
const elapsedLog = [];
|
const elapsedLog = [];
|
||||||
let finished = false;
|
let finished = false;
|
||||||
@ -200,11 +207,10 @@ function splitByChunks(items) {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Summarizes messages using the Extras API method.
|
* Summarizes messages using the Extras API method.
|
||||||
* @param {HashedMessage[]} hashedMessages Array of hashed messages
|
* @param {HashedMessage} element hashed message
|
||||||
* @returns {Promise<HashedMessage[]>} Summarized messages
|
* @returns {Promise<boolean>} Sucess
|
||||||
*/
|
*/
|
||||||
async function summarizeExtra(hashedMessages) {
|
async function summarizeExtra(element) {
|
||||||
for (const element of hashedMessages) {
|
|
||||||
try {
|
try {
|
||||||
const url = new URL(getApiUrl());
|
const url = new URL(getApiUrl());
|
||||||
url.pathname = '/api/summarize';
|
url.pathname = '/api/summarize';
|
||||||
@ -228,42 +234,37 @@ async function summarizeExtra(hashedMessages) {
|
|||||||
}
|
}
|
||||||
catch (error) {
|
catch (error) {
|
||||||
console.log(error);
|
console.log(error);
|
||||||
}
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return hashedMessages;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Summarizes messages using the main API method.
|
* Summarizes messages using the main API method.
|
||||||
* @param {HashedMessage[]} hashedMessages Array of hashed messages
|
* @param {HashedMessage} element hashed message
|
||||||
* @returns {Promise<HashedMessage[]>} Summarized messages
|
* @returns {Promise<boolean>} Sucess
|
||||||
*/
|
*/
|
||||||
async function summarizeMain(hashedMessages) {
|
async function summarizeMain(element) {
|
||||||
for (const element of hashedMessages) {
|
|
||||||
element.text = await generateRaw(element.text, '', false, false, settings.summary_prompt);
|
element.text = await generateRaw(element.text, '', false, false, settings.summary_prompt);
|
||||||
}
|
return true;
|
||||||
|
|
||||||
return hashedMessages;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Summarizes messages using WebLLM.
|
* Summarizes messages using WebLLM.
|
||||||
* @param {HashedMessage[]} hashedMessages Array of hashed messages
|
* @param {HashedMessage} element hashed message
|
||||||
* @returns {Promise<HashedMessage[]>} Summarized messages
|
* @returns {Promise<boolean>} Sucess
|
||||||
*/
|
*/
|
||||||
async function summarizeWebLLM(hashedMessages) {
|
async function summarizeWebLLM(element) {
|
||||||
if (!isWebLlmSupported()) {
|
if (!isWebLlmSupported()) {
|
||||||
console.warn('Vectors: WebLLM is not supported');
|
console.warn('Vectors: WebLLM is not supported');
|
||||||
return hashedMessages;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const element of hashedMessages) {
|
|
||||||
const messages = [{ role: 'system', content: settings.summary_prompt }, { role: 'user', content: element.text }];
|
const messages = [{ role: 'system', content: settings.summary_prompt }, { role: 'user', content: element.text }];
|
||||||
element.text = await generateWebLlmChatPrompt(messages);
|
element.text = await generateWebLlmChatPrompt(messages);
|
||||||
}
|
|
||||||
|
|
||||||
return hashedMessages;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -273,16 +274,35 @@ async function summarizeWebLLM(hashedMessages) {
|
|||||||
* @returns {Promise<HashedMessage[]>} Summarized messages
|
* @returns {Promise<HashedMessage[]>} Summarized messages
|
||||||
*/
|
*/
|
||||||
async function summarize(hashedMessages, endpoint = 'main') {
|
async function summarize(hashedMessages, endpoint = 'main') {
|
||||||
|
for (const element of hashedMessages) {
|
||||||
|
const cachedSummary = cachedSummaries.get(element.hash);
|
||||||
|
if (!cachedSummary) {
|
||||||
|
let success = true;
|
||||||
switch (endpoint) {
|
switch (endpoint) {
|
||||||
case 'main':
|
case 'main':
|
||||||
return await summarizeMain(hashedMessages);
|
success = await summarizeMain(element);
|
||||||
|
break;
|
||||||
case 'extras':
|
case 'extras':
|
||||||
return await summarizeExtra(hashedMessages);
|
success = await summarizeExtra(element);
|
||||||
|
break;
|
||||||
case 'webllm':
|
case 'webllm':
|
||||||
return await summarizeWebLLM(hashedMessages);
|
success = await summarizeWebLLM(element);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
console.error('Unsupported endpoint', endpoint);
|
console.error('Unsupported endpoint', endpoint);
|
||||||
|
success = false;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
if (success) {
|
||||||
|
cachedSummaries.set(element.hash, element.text);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
element.text = cachedSummary;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hashedMessages;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function synchronizeChat(batchSize = 5) {
|
async function synchronizeChat(batchSize = 5) {
|
||||||
@ -307,16 +327,15 @@ async function synchronizeChat(batchSize = 5) {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
let hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(substituteParams(x.mes)), hash: getStringHash(substituteParams(x.mes)), index: context.chat.indexOf(x) }));
|
const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(substituteParams(x.mes)), hash: getStringHash(substituteParams(x.mes)), index: context.chat.indexOf(x) }));
|
||||||
const hashesInCollection = await getSavedHashes(chatId);
|
const hashesInCollection = await getSavedHashes(chatId);
|
||||||
|
|
||||||
if (settings.summarize) {
|
let newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash));
|
||||||
hashedMessages = await summarize(hashedMessages, settings.summary_source);
|
|
||||||
}
|
|
||||||
|
|
||||||
const newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash));
|
|
||||||
const deletedHashes = hashesInCollection.filter(x => !hashedMessages.some(y => y.hash === x));
|
const deletedHashes = hashesInCollection.filter(x => !hashedMessages.some(y => y.hash === x));
|
||||||
|
|
||||||
|
if (settings.summarize) {
|
||||||
|
newVectorItems = await summarize(newVectorItems, settings.summary_source);
|
||||||
|
}
|
||||||
|
|
||||||
if (newVectorItems.length > 0) {
|
if (newVectorItems.length > 0) {
|
||||||
const chunkedBatch = splitByChunks(newVectorItems.slice(0, batchSize));
|
const chunkedBatch = splitByChunks(newVectorItems.slice(0, batchSize));
|
||||||
@ -687,25 +706,17 @@ const onChatEvent = debounce(async () => await moduleWorker.update(), debounce_t
|
|||||||
* @returns {Promise<string>} Text to query
|
* @returns {Promise<string>} Text to query
|
||||||
*/
|
*/
|
||||||
async function getQueryText(chat, initiator) {
|
async function getQueryText(chat, initiator) {
|
||||||
let queryText = '';
|
let hashedMessages = chat
|
||||||
let i = 0;
|
.map(x => ({ text: String(substituteParams(x.mes)), hash: getStringHash(substituteParams(x.mes)) }))
|
||||||
|
.filter(x => x.text)
|
||||||
let hashedMessages = chat.map(x => ({ text: String(substituteParams(x.mes)) }));
|
.reverse()
|
||||||
|
.slice(0, settings.query);
|
||||||
|
|
||||||
if (initiator === 'chat' && settings.enabled_chats && settings.summarize && settings.summarize_sent) {
|
if (initiator === 'chat' && settings.enabled_chats && settings.summarize && settings.summarize_sent) {
|
||||||
hashedMessages = await summarize(hashedMessages, settings.summary_source);
|
hashedMessages = await summarize(hashedMessages, settings.summary_source);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const message of hashedMessages.slice().reverse()) {
|
const queryText = hashedMessages.map(x => x.text).join('\n');
|
||||||
if (message.text) {
|
|
||||||
queryText += message.text + '\n';
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (i === settings.query) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return collapseNewlines(queryText).trim();
|
return collapseNewlines(queryText).trim();
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user