From 19df1f52cd3651338a8e258fc57d725b23309cbb Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Thu, 30 Nov 2023 00:01:59 +0200 Subject: [PATCH] Vector storage file retrieval --- public/script.js | 1 + public/scripts/chats.js | 4 +- public/scripts/extensions/vectors/index.js | 158 ++++++++++++++++-- .../scripts/extensions/vectors/settings.html | 155 +++++++++++------ server.js | 4 +- src/vectors.js | 44 ++--- 6 files changed, 276 insertions(+), 90 deletions(-) diff --git a/public/script.js b/public/script.js index d822c78d8..ea142fa50 100644 --- a/public/script.js +++ b/public/script.js @@ -9062,6 +9062,7 @@ jQuery(async function () { hideStopButton(); } eventSource.emit(event_types.GENERATION_STOPPED); + activateSendButtons(); }); $('.drawer-toggle').on('click', function () { diff --git a/public/scripts/chats.js b/public/scripts/chats.js index 7a806e37b..8627df034 100644 --- a/public/scripts/chats.js +++ b/public/scripts/chats.js @@ -119,7 +119,7 @@ export async function populateFileAttachment(message, inputId = 'file_form_input const fileText = await converter(file); base64Data = window.btoa(unescape(encodeURIComponent(fileText))); } catch (error) { - toastr.error(error, 'Could not convert file'); + toastr.error(String(error), 'Could not convert file'); console.error('Could not convert file', error); } } @@ -169,7 +169,7 @@ export async function uploadFileAttachment(fileName, base64Data) { const responseData = await result.json(); return responseData.path.replace(/\\/g, '/'); } catch (error) { - toastr.error(error, 'Could not upload file'); + toastr.error(String(error), 'Could not upload file'); console.error('Could not upload file', error); } } diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js index 05cda80dc..6c8c63ec3 100644 --- a/public/scripts/extensions/vectors/index.js +++ b/public/scripts/extensions/vectors/index.js @@ -2,28 +2,37 @@ import { eventSource, event_types, extension_prompt_types, getCurrentChatId, get import { ModuleWorkerWrapper, extension_settings, getContext, renderExtensionTemplate } from "../../extensions.js"; import { collapseNewlines, power_user, ui_mode } from "../../power-user.js"; import { SECRET_KEYS, secret_state } from "../../secrets.js"; -import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique } from "../../utils.js"; +import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique, splitRecursive } from "../../utils.js"; const MODULE_NAME = 'vectors'; export const EXTENSION_PROMPT_TAG = '3_vectors'; const settings = { - enabled: false, + // For both source: 'transformers', + + // For chats + enabled_chats: false, template: `Past events: {{text}}`, depth: 2, position: extension_prompt_types.IN_PROMPT, protect: 5, insert: 3, query: 2, + + // For files + enabled_files: false, + size_threshold: 5, + chunk_size: 1000, + chunk_count: 4, }; const moduleWorker = new ModuleWorkerWrapper(synchronizeChat); async function onVectorizeAllClick() { try { - if (!settings.enabled) { + if (!settings.enabled_chats) { return; } @@ -78,7 +87,7 @@ async function onVectorizeAllClick() { let syncBlocked = false; async function synchronizeChat(batchSize = 5) { - if (!settings.enabled) { + if (!settings.enabled_chats) { return -1; } @@ -99,7 +108,7 @@ async function synchronizeChat(batchSize = 5) { return -1; } - const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(x.mes), hash: getStringHash(x.mes) })); + const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(x.mes), hash: getStringHash(x.mes), index: context.chat.indexOf(x) })); const hashesInCollection = await getSavedHashes(chatId); const newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash)); @@ -149,6 +158,92 @@ function getStringHash(str) { return hash; } +/** + * Retrieves files from the chat and inserts them into the vector index. + * @param {object[]} chat Array of chat messages + * @returns {Promise} + */ +async function processFiles(chat) { + try { + if (!settings.enabled_files) { + return; + } + + for (const message of chat) { + // Message has no file + if (!message?.extra?.file) { + continue; + } + + // Trim file inserted by the script + const fileText = message.mes.substring(message.extra.fileStart).trim(); + + // Convert kilobytes to string length + const thresholdLength = settings.size_threshold * 1024; + + // File is too small + if (fileText.length < thresholdLength) { + continue; + } + + message.mes = message.mes.substring(0, message.extra.fileStart); + + const fileName = message.extra.file.name; + const collectionId = `file_${getStringHash(fileName)}`; + const hashesInCollection = await getSavedHashes(collectionId); + + // File is already in the collection + if (!hashesInCollection.length) { + await vectorizeFile(fileText, fileName, collectionId); + } + + const queryText = getQueryText(chat); + const fileChunks = await retrieveFileChunks(queryText, collectionId); + + message.mes += '\n\n' + fileChunks; + } + } catch (error) { + console.error('Vectors: Failed to retrieve files', error); + } +} + +/** + * Retrieves file chunks from the vector index and inserts them into the chat. + * @param {string} queryText Text to query + * @param {string} collectionId File collection ID + * @returns {Promise} Retrieved file text + */ +async function retrieveFileChunks(queryText, collectionId) { + console.debug(`Vectors: Retrieving file chunks for collection ${collectionId}`, queryText); + const queryResults = await queryCollection(collectionId, queryText, settings.chunk_count); + console.debug(`Vectors: Retrieved ${queryResults.hashes.length} file chunks for collection ${collectionId}`, queryResults); + const metadata = queryResults.metadata.filter(x => x.text).sort((a, b) => a.index - b.index).map(x => x.text); + const fileText = metadata.join('\n'); + + return fileText; +} + +/** + * Vectorizes a file and inserts it into the vector index. + * @param {string} fileText File text + * @param {string} fileName File name + * @param {string} collectionId File collection ID + */ +async function vectorizeFile(fileText, fileName, collectionId) { + try { + toastr.info("Vectorization may take some time, please wait...", `Ingesting file ${fileName}`); + const chunks = splitRecursive(fileText, settings.chunk_size); + console.debug(`Vectors: Split file ${fileName} into ${chunks.length} chunks`, chunks); + + const items = chunks.map((chunk, index) => ({ hash: getStringHash(chunk), text: chunk, index: index })); + await insertVectorItems(collectionId, items); + + console.log(`Vectors: Inserted ${chunks.length} vector items for file ${fileName} into ${collectionId}`); + } catch (error) { + console.error('Vectors: Failed to vectorize file', error); + } +} + /** * Removes the most relevant messages from the chat and displays them in the extension prompt * @param {object[]} chat Array of chat messages @@ -158,7 +253,11 @@ async function rearrangeChat(chat) { // Clear the extension prompt setExtensionPrompt(EXTENSION_PROMPT_TAG, '', extension_prompt_types.IN_PROMPT, 0); - if (!settings.enabled) { + if (settings.enabled_files) { + await processFiles(chat); + } + + if (!settings.enabled_chats) { return; } @@ -182,7 +281,8 @@ async function rearrangeChat(chat) { } // Get the most relevant messages, excluding the last few - const queryHashes = (await queryCollection(chatId, queryText, settings.insert)).filter(onlyUnique); + const queryResults = await queryCollection(chatId, queryText, settings.query); + const queryHashes = queryResults.hashes.filter(onlyUnique); const queriedMessages = []; const insertedHashes = new Set(); const retainMessages = chat.slice(-settings.protect); @@ -335,7 +435,7 @@ async function deleteVectorItems(collectionId, hashes) { * @param {string} collectionId - The collection to query * @param {string} searchText - The text to query * @param {number} topK - The number of results to return - * @returns {Promise} - Hashes of the results + * @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results */ async function queryCollection(collectionId, searchText, topK) { const response = await fetch('/api/vector/query', { @@ -359,7 +459,7 @@ async function queryCollection(collectionId, searchText, topK) { async function purgeVectorIndex(collectionId) { try { - if (!settings.enabled) { + if (!settings.enabled_chats) { return; } @@ -382,19 +482,36 @@ async function purgeVectorIndex(collectionId) { } } +function toggleSettings() { + $('#vectors_files_settings').toggle(!!settings.enabled_files); + $('#vectors_chats_settings').toggle(!!settings.enabled_chats); +} + jQuery(async () => { if (!extension_settings.vectors) { extension_settings.vectors = settings; } + // Migrate from old settings + if (settings['enabled']) { + settings.enabled_chats = true; + } + Object.assign(settings, extension_settings.vectors); // Migrate from TensorFlow to Transformers settings.source = settings.source !== 'local' ? settings.source : 'transformers'; $('#extensions_settings2').append(renderExtensionTemplate(MODULE_NAME, 'settings')); - $('#vectors_enabled').prop('checked', settings.enabled).on('input', () => { - settings.enabled = $('#vectors_enabled').prop('checked'); + $('#vectors_enabled_chats').prop('checked', settings.enabled_chats).on('input', () => { + settings.enabled_chats = $('#vectors_enabled_chats').prop('checked'); Object.assign(extension_settings.vectors, settings); saveSettingsDebounced(); + toggleSettings(); + }); + $('#vectors_enabled_files').prop('checked', settings.enabled_files).on('input', () => { + settings.enabled_files = $('#vectors_enabled_files').prop('checked'); + Object.assign(extension_settings.vectors, settings); + saveSettingsDebounced(); + toggleSettings(); }); $('#vectors_source').val(settings.source).on('change', () => { settings.source = String($('#vectors_source').val()); @@ -436,6 +553,25 @@ jQuery(async () => { $('#vectors_vectorize_all').on('click', onVectorizeAllClick); + $('#vectors_size_threshold').val(settings.size_threshold).on('input', () => { + settings.size_threshold = Number($('#vectors_size_threshold').val()); + Object.assign(extension_settings.vectors, settings); + saveSettingsDebounced(); + }); + + $('#vectors_chunk_size').val(settings.chunk_size).on('input', () => { + settings.chunk_size = Number($('#vectors_chunk_size').val()); + Object.assign(extension_settings.vectors, settings); + saveSettingsDebounced(); + }); + + $('#vectors_chunk_count').val(settings.chunk_count).on('input', () => { + settings.chunk_count = Number($('#vectors_chunk_count').val()); + Object.assign(extension_settings.vectors, settings); + saveSettingsDebounced(); + }); + + toggleSettings(); eventSource.on(event_types.MESSAGE_DELETED, onChatEvent); eventSource.on(event_types.MESSAGE_EDITED, onChatEvent); eventSource.on(event_types.MESSAGE_SENT, onChatEvent); diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html index 185a2334c..fa6c6f4c7 100644 --- a/public/scripts/extensions/vectors/settings.html +++ b/public/scripts/extensions/vectors/settings.html @@ -5,72 +5,119 @@
- - - -
-