Vector storage file retrieval

This commit is contained in:
Cohee
2023-11-30 00:01:59 +02:00
parent e0bf2b8e3e
commit 19df1f52cd
6 changed files with 276 additions and 90 deletions

View File

@ -2,28 +2,37 @@ import { eventSource, event_types, extension_prompt_types, getCurrentChatId, get
import { ModuleWorkerWrapper, extension_settings, getContext, renderExtensionTemplate } from "../../extensions.js";
import { collapseNewlines, power_user, ui_mode } from "../../power-user.js";
import { SECRET_KEYS, secret_state } from "../../secrets.js";
import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique } from "../../utils.js";
import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique, splitRecursive } from "../../utils.js";
// Extension identifier; passed to renderExtensionTemplate when the settings template is rendered.
const MODULE_NAME = 'vectors';
// Tag passed to setExtensionPrompt when this extension sets or clears its prompt.
export const EXTENSION_PROMPT_TAG = '3_vectors';
// Default settings. At startup the saved extension_settings.vectors values are copied over these.
const settings = {
// NOTE(review): presumably the legacy master switch - the startup code checks
// settings['enabled'] and turns on enabled_chats when it is truthy. Confirm it is still needed.
enabled: false,
// For both
// Selected vectorization source (bound to #vectors_source); a stored value of 'local' is rewritten to 'transformers' at startup.
source: 'transformers',
// For chats
// Gates chat synchronization, "vectorize all", index purging and the chat part of rearrangeChat.
enabled_chats: false,
// NOTE(review): template/depth/position are presumably used when the retrieved messages
// are written to the extension prompt - their usage is not visible in this section.
template: `Past events: {{text}}`,
depth: 2,
position: extension_prompt_types.IN_PROMPT,
// Number of most recent chat messages retained by rearrangeChat (chat.slice(-protect)).
protect: 5,
// NOTE(review): presumably the number of retrieved messages to insert - confirm against rearrangeChat.
insert: 3,
// Passed as topK to queryCollection when querying the chat collection.
query: 2,
// For files
// Gates processFiles.
enabled_files: false,
// Size in kilobytes (multiplied by 1024 and compared with the attachment text length); shorter files are left in the message untouched.
size_threshold: 5,
// Passed to splitRecursive as the chunk length when a file is vectorized.
chunk_size: 1000,
// Passed as topK to queryCollection when retrieving file chunks.
chunk_count: 4,
};
// Wraps synchronizeChat in a ModuleWorkerWrapper.
const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
async function onVectorizeAllClick() {
try {
if (!settings.enabled) {
if (!settings.enabled_chats) {
return;
}
@ -78,7 +87,7 @@ async function onVectorizeAllClick() {
let syncBlocked = false;
async function synchronizeChat(batchSize = 5) {
if (!settings.enabled) {
if (!settings.enabled_chats) {
return -1;
}
@ -99,7 +108,7 @@ async function synchronizeChat(batchSize = 5) {
return -1;
}
const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(x.mes), hash: getStringHash(x.mes) }));
const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(x.mes), hash: getStringHash(x.mes), index: context.chat.indexOf(x) }));
const hashesInCollection = await getSavedHashes(chatId);
const newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash));
@ -149,6 +158,92 @@ function getStringHash(str) {
return hash;
}
/**
 * Replaces large file attachments in chat messages with their most relevant chunks.
 *
 * For every message that carries a file (`message.extra.file`) whose text is at least
 * `settings.size_threshold` kilobytes long, the file text is cut out of `message.mes`,
 * vectorized into its own collection if that collection is still empty, and replaced
 * with the chunks returned by a vector query built from the chat.
 *
 * Failures are handled per message: the message text is restored to what it was before
 * processing (so the attachment is not silently dropped from the prompt), the error is
 * logged, and the remaining messages are still processed. The function never rejects.
 *
 * @param {object[]} chat Array of chat messages (mutated in place)
 * @returns {Promise<void>}
 */
async function processFiles(chat) {
    try {
        if (!settings.enabled_files) {
            return;
        }

        // Convert kilobytes to string length (loop-invariant, so computed once)
        const thresholdLength = settings.size_threshold * 1024;

        for (const message of chat) {
            // Message has no file
            if (!message?.extra?.file) {
                continue;
            }

            // Snapshot taken before any mutation so a failure below can be rolled back
            const originalText = message.mes;

            try {
                // Trim file inserted by the script
                const fileText = message.mes.substring(message.extra.fileStart).trim();

                // File is too small
                if (fileText.length < thresholdLength) {
                    continue;
                }

                // The file is cut out BEFORE the query text is built, so the query is
                // derived from the conversation rather than from the attachment itself.
                message.mes = message.mes.substring(0, message.extra.fileStart);

                const fileName = message.extra.file.name;
                const collectionId = `file_${getStringHash(fileName)}`;
                const hashesInCollection = await getSavedHashes(collectionId);

                // File is not in the collection yet
                if (!hashesInCollection.length) {
                    await vectorizeFile(fileText, fileName, collectionId);
                }

                const queryText = getQueryText(chat);
                const fileChunks = await retrieveFileChunks(queryText, collectionId);

                message.mes += '\n\n' + fileChunks;
            } catch (error) {
                // Roll back the truncation: keeping the full file is better than losing it
                message.mes = originalText;
                console.error('Vectors: Failed to retrieve file chunks, keeping the original message text', error);
            }
        }
    } catch (error) {
        console.error('Vectors: Failed to retrieve files', error);
    }
}
/**
 * Queries a file collection and assembles the matching chunks into one string.
 * Chunks without text are dropped; the rest are ordered by their `index` and
 * joined with newlines.
 * @param {string} queryText Text to query
 * @param {string} collectionId File collection ID
 * @returns {Promise<string>} Retrieved file text
 */
async function retrieveFileChunks(queryText, collectionId) {
    console.debug(`Vectors: Retrieving file chunks for collection ${collectionId}`, queryText);
    const results = await queryCollection(collectionId, queryText, settings.chunk_count);
    console.debug(`Vectors: Retrieved ${results.hashes.length} file chunks for collection ${collectionId}`, results);

    // Collect only the entries that actually carry text
    const chunksWithText = [];
    for (const entry of results.metadata) {
        if (entry.text) {
            chunksWithText.push(entry);
        }
    }

    // Order the chunks by their stored index
    chunksWithText.sort((left, right) => left.index - right.index);

    return chunksWithText.map((entry) => entry.text).join('\n');
}
/**
 * Splits a file into chunks and stores them in the vector index.
 * Errors are logged and not propagated to the caller.
 * @param {string} fileText File text
 * @param {string} fileName File name
 * @param {string} collectionId File collection ID
 * @returns {Promise<void>}
 */
async function vectorizeFile(fileText, fileName, collectionId) {
    // Builds the vector item for one chunk; `index` records the chunk's position in the file
    const toVectorItem = (text, position) => ({ hash: getStringHash(text), text, index: position });

    try {
        toastr.info("Vectorization may take some time, please wait...", `Ingesting file ${fileName}`);

        const chunks = splitRecursive(fileText, settings.chunk_size);
        console.debug(`Vectors: Split file ${fileName} into ${chunks.length} chunks`, chunks);

        await insertVectorItems(collectionId, chunks.map(toVectorItem));

        console.log(`Vectors: Inserted ${chunks.length} vector items for file ${fileName} into ${collectionId}`);
    } catch (error) {
        console.error('Vectors: Failed to vectorize file', error);
    }
}
/**
* Removes the most relevant messages from the chat and displays them in the extension prompt
* @param {object[]} chat Array of chat messages
@ -158,7 +253,11 @@ async function rearrangeChat(chat) {
// Clear the extension prompt
setExtensionPrompt(EXTENSION_PROMPT_TAG, '', extension_prompt_types.IN_PROMPT, 0);
if (!settings.enabled) {
if (settings.enabled_files) {
await processFiles(chat);
}
if (!settings.enabled_chats) {
return;
}
@ -182,7 +281,8 @@ async function rearrangeChat(chat) {
}
// Get the most relevant messages, excluding the last few
const queryHashes = (await queryCollection(chatId, queryText, settings.insert)).filter(onlyUnique);
const queryResults = await queryCollection(chatId, queryText, settings.query);
const queryHashes = queryResults.hashes.filter(onlyUnique);
const queriedMessages = [];
const insertedHashes = new Set();
const retainMessages = chat.slice(-settings.protect);
@ -335,7 +435,7 @@ async function deleteVectorItems(collectionId, hashes) {
* @param {string} collectionId - The collection to query
* @param {string} searchText - The text to query
* @param {number} topK - The number of results to return
* @returns {Promise<number[]>} - Hashes of the results
* @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results
*/
async function queryCollection(collectionId, searchText, topK) {
const response = await fetch('/api/vector/query', {
@ -359,7 +459,7 @@ async function queryCollection(collectionId, searchText, topK) {
async function purgeVectorIndex(collectionId) {
try {
if (!settings.enabled) {
if (!settings.enabled_chats) {
return;
}
@ -382,19 +482,36 @@ async function purgeVectorIndex(collectionId) {
}
}
/**
 * Shows or hides the file and chat settings sections to match their enable flags.
 */
function toggleSettings() {
    const sectionFlags = [
        ['#vectors_files_settings', settings.enabled_files],
        ['#vectors_chats_settings', settings.enabled_chats],
    ];

    for (const [selector, flag] of sectionFlags) {
        // Coerce to a real boolean before handing the state to toggle()
        $(selector).toggle(Boolean(flag));
    }
}
jQuery(async () => {
if (!extension_settings.vectors) {
extension_settings.vectors = settings;
}
// Migrate from old settings
if (settings['enabled']) {
settings.enabled_chats = true;
}
Object.assign(settings, extension_settings.vectors);
// Migrate from TensorFlow to Transformers
settings.source = settings.source !== 'local' ? settings.source : 'transformers';
$('#extensions_settings2').append(renderExtensionTemplate(MODULE_NAME, 'settings'));
$('#vectors_enabled').prop('checked', settings.enabled).on('input', () => {
settings.enabled = $('#vectors_enabled').prop('checked');
$('#vectors_enabled_chats').prop('checked', settings.enabled_chats).on('input', () => {
settings.enabled_chats = $('#vectors_enabled_chats').prop('checked');
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
toggleSettings();
});
$('#vectors_enabled_files').prop('checked', settings.enabled_files).on('input', () => {
settings.enabled_files = $('#vectors_enabled_files').prop('checked');
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
toggleSettings();
});
$('#vectors_source').val(settings.source).on('change', () => {
settings.source = String($('#vectors_source').val());
@ -436,6 +553,25 @@ jQuery(async () => {
$('#vectors_vectorize_all').on('click', onVectorizeAllClick);
$('#vectors_size_threshold').val(settings.size_threshold).on('input', () => {
settings.size_threshold = Number($('#vectors_size_threshold').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_chunk_size').val(settings.chunk_size).on('input', () => {
settings.chunk_size = Number($('#vectors_chunk_size').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_chunk_count').val(settings.chunk_count).on('input', () => {
settings.chunk_count = Number($('#vectors_chunk_count').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
toggleSettings();
eventSource.on(event_types.MESSAGE_DELETED, onChatEvent);
eventSource.on(event_types.MESSAGE_EDITED, onChatEvent);
eventSource.on(event_types.MESSAGE_SENT, onChatEvent);