Implement Data Bank vectors querying

This commit is contained in:
Cohee
2024-04-17 02:09:22 +03:00
parent 4665db62f4
commit 9a1ea7f226
6 changed files with 310 additions and 16 deletions

View File

@@ -1,12 +1,14 @@
import { eventSource, event_types, extension_prompt_types, getCurrentChatId, getRequestHeaders, is_send_press, saveSettingsDebounced, setExtensionPrompt, substituteParams } from '../../../script.js';
import { eventSource, event_types, extension_prompt_roles, extension_prompt_types, getCurrentChatId, getRequestHeaders, is_send_press, saveSettingsDebounced, setExtensionPrompt, substituteParams } from '../../../script.js';
import { getDataBankAttachments, getFileAttachment } from '../../chats.js';
import { ModuleWorkerWrapper, extension_settings, getContext, modules, renderExtensionTemplateAsync } from '../../extensions.js';
import { collapseNewlines } from '../../power-user.js';
import { SECRET_KEYS, secret_state, writeSecret } from '../../secrets.js';
import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique, splitRecursive } from '../../utils.js';
import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique, splitRecursive, getFileText } from '../../utils.js';
const MODULE_NAME = 'vectors';
export const EXTENSION_PROMPT_TAG = '3_vectors';
export const EXTENSION_PROMPT_TAG_DB = '4_vectors_data_bank';
const settings = {
// For both
@@ -17,7 +19,7 @@ const settings = {
// For chats
enabled_chats: false,
template: 'Past events: {{text}}',
template: 'Past events:\n{{text}}',
depth: 2,
position: extension_prompt_types.IN_PROMPT,
protect: 5,
@@ -30,6 +32,15 @@ const settings = {
size_threshold: 10,
chunk_size: 5000,
chunk_count: 2,
// For Data Bank
size_threshold_db: 5,
chunk_size_db: 2500,
chunk_count_db: 5,
file_template_db: 'Related information:\n{{text}}',
file_position_db: extension_prompt_types.IN_PROMPT,
file_depth_db: 4,
file_depth_role_db: extension_prompt_roles.SYSTEM,
};
const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
@@ -214,6 +225,34 @@ async function processFiles(chat) {
return;
}
const dataBank = getDataBankAttachments();
const dataBankCollectionIds = [];
for (const file of dataBank) {
const collectionId = `file_${getStringHash(file.url)}`;
const hashesInCollection = await getSavedHashes(collectionId);
dataBankCollectionIds.push(collectionId);
// File is already in the collection
if (hashesInCollection.length) {
continue;
}
// Download and process the file
file.text = await getFileAttachment(file.url);
console.log(`Vectors: Retrieved file ${file.name} from Data Bank`);
// Convert kilobytes to string length
const thresholdLength = settings.size_threshold_db * 1024;
// Use chunk size from settings if file is larger than threshold
const chunkSize = file.size > thresholdLength ? settings.chunk_size_db : -1;
await vectorizeFile(file.text, file.name, collectionId, chunkSize);
}
if (dataBankCollectionIds.length) {
const queryText = getQueryText(chat);
await injectDataBankChunks(queryText, dataBankCollectionIds);
}
for (const message of chat) {
// Message has no file
if (!message?.extra?.file) {
@@ -240,7 +279,7 @@ async function processFiles(chat) {
// File is already in the collection
if (!hashesInCollection.length) {
await vectorizeFile(fileText, fileName, collectionId);
await vectorizeFile(fileText, fileName, collectionId, settings.chunk_size);
}
const queryText = getQueryText(chat);
@@ -253,6 +292,36 @@ async function processFiles(chat) {
}
}
/**
* Inserts file chunks from the Data Bank into the prompt.
* @param {string} queryText Text to query
* @param {string[]} collectionIds File collection IDs
* @returns {Promise<void>}
*/
async function injectDataBankChunks(queryText, collectionIds) {
try {
const queryResults = await queryMultipleCollections(collectionIds, queryText, settings.chunk_count_db);
console.debug(`Vectors: Retrieved ${collectionIds.length} Data Bank collections`, queryResults);
let textResult = '';
for (const collectionId in queryResults) {
console.debug(`Vectors: Processing Data Bank collection ${collectionId}`, queryResults[collectionId]);
const metadata = queryResults[collectionId].metadata?.filter(x => x.text)?.sort((a, b) => a.index - b.index)?.map(x => x.text)?.filter(onlyUnique) || [];
textResult += metadata.join('\n') + '\n\n';
}
if (!textResult) {
console.debug('Vectors: No Data Bank chunks found');
return;
}
const insertedText = substituteParams(settings.file_template_db.replace(/{{text}}/i, textResult));
setExtensionPrompt(EXTENSION_PROMPT_TAG_DB, insertedText, settings.file_position_db, settings.file_depth_db, settings.include_wi, settings.file_depth_role_db);
} catch (error) {
console.error('Vectors: Failed to insert Data Bank chunks', error);
}
}
/**
* Retrieves file chunks from the vector index and inserts them into the chat.
* @param {string} queryText Text to query
@@ -274,11 +343,12 @@ async function retrieveFileChunks(queryText, collectionId) {
* @param {string} fileText File text
* @param {string} fileName File name
* @param {string} collectionId File collection ID
* @param {number} chunkSize Chunk size
*/
async function vectorizeFile(fileText, fileName, collectionId) {
async function vectorizeFile(fileText, fileName, collectionId, chunkSize) {
try {
toastr.info('Vectorization may take some time, please wait...', `Ingesting file ${fileName}`);
const chunks = splitRecursive(fileText, settings.chunk_size);
const chunks = splitRecursive(fileText, chunkSize);
console.debug(`Vectors: Split file ${fileName} into ${chunks.length} chunks`, chunks);
const items = chunks.map((chunk, index) => ({ hash: getStringHash(chunk), text: chunk, index: index }));
@@ -297,7 +367,8 @@ async function vectorizeFile(fileText, fileName, collectionId) {
async function rearrangeChat(chat) {
try {
// Clear the extension prompt
setExtensionPrompt(EXTENSION_PROMPT_TAG, '', extension_prompt_types.IN_PROMPT, 0, settings.include_wi);
setExtensionPrompt(EXTENSION_PROMPT_TAG, '', settings.position, settings.depth, settings.include_wi);
setExtensionPrompt(EXTENSION_PROMPT_TAG_DB, '', settings.file_position_db, settings.file_depth_db, settings.include_wi, settings.file_depth_role_db);
if (settings.enabled_files) {
await processFiles(chat);
@@ -563,6 +634,34 @@ async function queryCollection(collectionId, searchText, topK) {
return await response.json();
}
/**
* Queries multiple collections for a given text.
* @param {string[]} collectionIds - Collection IDs to query
* @param {string} searchText - Text to query
* @param {number} topK - Number of results to return
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - Results mapped to collection IDs
*/
async function queryMultipleCollections(collectionIds, searchText, topK) {
const headers = getVectorHeaders();
const response = await fetch('/api/vector/query-multi', {
method: 'POST',
headers: headers,
body: JSON.stringify({
collectionIds: collectionIds,
searchText: searchText,
topK: topK,
source: settings.source,
}),
});
if (!response.ok) {
throw new Error('Failed to query multiple collections');
}
return await response.json();
}
/**
* Purges the vector index for a collection.
* @param {string} collectionId Collection ID to purge
@@ -761,6 +860,49 @@ jQuery(async () => {
saveSettingsDebounced();
});
$('#vectors_size_threshold_db').val(settings.size_threshold_db).on('input', () => {
settings.size_threshold_db = Number($('#vectors_size_threshold_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_chunk_size_db').val(settings.chunk_size_db).on('input', () => {
settings.chunk_size_db = Number($('#vectors_chunk_size_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_chunk_count_db').val(settings.chunk_count_db).on('input', () => {
settings.chunk_count_db = Number($('#vectors_chunk_count_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_file_template_db').val(settings.file_template_db).on('input', () => {
settings.file_template_db = String($('#vectors_file_template_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$(`input[name="vectors_file_position_db"][value="${settings.file_position_db}"]`).prop('checked', true);
$('input[name="vectors_file_position_db"]').on('change', () => {
settings.file_position_db = Number($('input[name="vectors_file_position_db"]:checked').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_file_depth_db').val(settings.file_depth_db).on('input', () => {
settings.file_depth_db = Number($('#vectors_file_depth_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_file_depth_role_db').val(settings.file_depth_role_db).on('input', () => {
settings.file_depth_role_db = Number($('#vectors_file_depth_role_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
const validSecret = !!secret_state[SECRET_KEYS.NOMICAI];
const placeholder = validSecret ? '✔️ Key saved' : '❌ Missing key';
$('#api_key_nomicai').attr('placeholder', placeholder);