Implement Data Bank vectors querying

Cohee 2024-04-17 02:09:22 +03:00
parent 4665db62f4
commit 9a1ea7f226
6 changed files with 310 additions and 16 deletions

View File

@@ -757,7 +757,7 @@ async function openAttachmentManager() {
*/
async function openWebpageScraper(target, callback) {
const template = $(await renderExtensionTemplateAsync('attachments', 'web-scrape', {}));
const link = await callGenericPopup(template, POPUP_TYPE.INPUT, '', { wide: false, large: false });
const link = await callGenericPopup(template, POPUP_TYPE.INPUT, '', { wide: false, large: false, okButton: 'Scrape', cancelButton: 'Cancel' });
if (!link) {
return;
@@ -825,7 +825,7 @@ async function openFandomScraper(target, callback) {
output = String($(this).val());
});
const confirm = await callGenericPopup(template, POPUP_TYPE.CONFIRM, '', { wide: false, large: false });
const confirm = await callGenericPopup(template, POPUP_TYPE.CONFIRM, '', { wide: false, large: false, okButton: 'Scrape', cancelButton: 'Cancel' });
if (confirm !== POPUP_RESULT.AFFIRMATIVE) {
return;
@@ -994,6 +994,19 @@ function ensureAttachmentsExist() {
}
}
/**
* Gets all currently available attachments.
* @returns {FileAttachment[]} List of attachments
*/
export function getDataBankAttachments() {
ensureAttachmentsExist();
const globalAttachments = extension_settings.attachments ?? [];
const chatAttachments = chat_metadata.attachments ?? [];
const characterAttachments = characters[this_chid]?.data?.extensions?.attachments ?? [];
return [...globalAttachments, ...chatAttachments, ...characterAttachments];
}
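A minimal usage sketch of the helper above; the import path matches the one used by the vectors extension below, and the logging is purely illustrative:

import { getDataBankAttachments, getFileAttachment } from '../../chats.js';

// List every Data Bank file visible in the current context (global, chat, and character scopes)
async function listDataBankFiles() {
    const attachments = getDataBankAttachments();
    for (const file of attachments) {
        const text = await getFileAttachment(file.url);
        console.log(`${file.name}: ${text.length} characters`);
    }
}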
/**
* Probes the server to check if the Fandom plugin is available.
* @returns {Promise<boolean>} True if the plugin is available, false otherwise.

View File

@@ -1,12 +1,14 @@
import { eventSource, event_types, extension_prompt_types, getCurrentChatId, getRequestHeaders, is_send_press, saveSettingsDebounced, setExtensionPrompt, substituteParams } from '../../../script.js';
import { eventSource, event_types, extension_prompt_roles, extension_prompt_types, getCurrentChatId, getRequestHeaders, is_send_press, saveSettingsDebounced, setExtensionPrompt, substituteParams } from '../../../script.js';
import { getDataBankAttachments, getFileAttachment } from '../../chats.js';
import { ModuleWorkerWrapper, extension_settings, getContext, modules, renderExtensionTemplateAsync } from '../../extensions.js';
import { collapseNewlines } from '../../power-user.js';
import { SECRET_KEYS, secret_state, writeSecret } from '../../secrets.js';
import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique, splitRecursive } from '../../utils.js';
import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique, splitRecursive, getFileText } from '../../utils.js';
const MODULE_NAME = 'vectors';
export const EXTENSION_PROMPT_TAG = '3_vectors';
export const EXTENSION_PROMPT_TAG_DB = '4_vectors_data_bank';
const settings = {
// For both
@@ -17,7 +19,7 @@ const settings = {
// For chats
enabled_chats: false,
template: 'Past events: {{text}}',
template: 'Past events:\n{{text}}',
depth: 2,
position: extension_prompt_types.IN_PROMPT,
protect: 5,
@@ -30,6 +32,15 @@ const settings = {
size_threshold: 10,
chunk_size: 5000,
chunk_count: 2,
// For Data Bank
size_threshold_db: 5,
chunk_size_db: 2500,
chunk_count_db: 5,
file_template_db: 'Related information:\n{{text}}',
file_position_db: extension_prompt_types.IN_PROMPT,
file_depth_db: 4,
file_depth_role_db: extension_prompt_roles.SYSTEM,
};
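// With these defaults, a Data Bank file under the 5 KB threshold is embedded as a single chunk,
// larger files are split into 2500-character chunks, and the 5 best-matching chunks across all
// files are injected via the file template (depth 4 and the system role apply when the in-chat
// position is selected).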
const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
@@ -214,6 +225,34 @@ async function processFiles(chat) {
return;
}
const dataBank = getDataBankAttachments();
const dataBankCollectionIds = [];
for (const file of dataBank) {
const collectionId = `file_${getStringHash(file.url)}`;
const hashesInCollection = await getSavedHashes(collectionId);
dataBankCollectionIds.push(collectionId);
// File is already in the collection
if (hashesInCollection.length) {
continue;
}
// Download and process the file
file.text = await getFileAttachment(file.url);
console.log(`Vectors: Retrieved file ${file.name} from Data Bank`);
// Convert kilobytes to string length
const thresholdLength = settings.size_threshold_db * 1024;
// Use chunk size from settings if file is larger than threshold
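// (a chunk size of -1 makes splitRecursive skip splitting, so the whole file becomes one chunk)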
const chunkSize = file.size > thresholdLength ? settings.chunk_size_db : -1;
await vectorizeFile(file.text, file.name, collectionId, chunkSize);
}
if (dataBankCollectionIds.length) {
const queryText = getQueryText(chat);
await injectDataBankChunks(queryText, dataBankCollectionIds);
}
for (const message of chat) {
// Message has no file
if (!message?.extra?.file) {
@@ -240,7 +279,7 @@ async function processFiles(chat) {
// Only vectorize the file if it is not already in the collection
if (!hashesInCollection.length) {
await vectorizeFile(fileText, fileName, collectionId);
await vectorizeFile(fileText, fileName, collectionId, settings.chunk_size);
}
const queryText = getQueryText(chat);
@@ -253,6 +292,36 @@ async function processFiles(chat) {
}
}
/**
* Inserts file chunks from the Data Bank into the prompt.
* @param {string} queryText Text to query
* @param {string[]} collectionIds File collection IDs
* @returns {Promise<void>}
*/
async function injectDataBankChunks(queryText, collectionIds) {
try {
const queryResults = await queryMultipleCollections(collectionIds, queryText, settings.chunk_count_db);
console.debug(`Vectors: Retrieved ${collectionIds.length} Data Bank collections`, queryResults);
let textResult = '';
for (const collectionId in queryResults) {
console.debug(`Vectors: Processing Data Bank collection ${collectionId}`, queryResults[collectionId]);
const metadata = queryResults[collectionId].metadata?.filter(x => x.text)?.sort((a, b) => a.index - b.index)?.map(x => x.text)?.filter(onlyUnique) || [];
textResult += metadata.join('\n') + '\n\n';
}
if (!textResult) {
console.debug('Vectors: No Data Bank chunks found');
return;
}
const insertedText = substituteParams(settings.file_template_db.replace(/{{text}}/i, textResult));
setExtensionPrompt(EXTENSION_PROMPT_TAG_DB, insertedText, settings.file_position_db, settings.file_depth_db, settings.include_wi, settings.file_depth_role_db);
} catch (error) {
console.error('Vectors: Failed to insert Data Bank chunks', error);
}
}
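To make the assembly above concrete, here is a self-contained sketch with invented collection IDs and chunk texts (the onlyUnique de-duplication is omitted for brevity):

const queryResults = {
    file_1111111111: { hashes: [11, 22], metadata: [{ hash: 22, index: 1, text: 'Chunk B' }, { hash: 11, index: 0, text: 'Chunk A' }] },
    file_2222222222: { hashes: [33], metadata: [{ hash: 33, index: 0, text: 'Chunk C' }] },
};
let textResult = '';
for (const collectionId in queryResults) {
    // Chunks are restored to their original in-file order, then joined
    const texts = queryResults[collectionId].metadata.filter(x => x.text).sort((a, b) => a.index - b.index).map(x => x.text);
    textResult += texts.join('\n') + '\n\n';
}
// With the default file_template_db, the injected text becomes:
// Related information:
// Chunk A
// Chunk B
//
// Chunk C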
/**
* Retrieves file chunks from the vector index and inserts them into the chat.
* @param {string} queryText Text to query
@@ -274,11 +343,12 @@ async function retrieveFileChunks(queryText, collectionId) {
* @param {string} fileText File text
* @param {string} fileName File name
* @param {string} collectionId File collection ID
* @param {number} chunkSize Chunk size in characters (-1 to keep the whole file as a single chunk)
*/
async function vectorizeFile(fileText, fileName, collectionId) {
async function vectorizeFile(fileText, fileName, collectionId, chunkSize) {
try {
toastr.info('Vectorization may take some time, please wait...', `Ingesting file ${fileName}`);
const chunks = splitRecursive(fileText, settings.chunk_size);
const chunks = splitRecursive(fileText, chunkSize);
console.debug(`Vectors: Split file ${fileName} into ${chunks.length} chunks`, chunks);
const items = chunks.map((chunk, index) => ({ hash: getStringHash(chunk), text: chunk, index: index }));
@@ -297,7 +367,8 @@ async function vectorizeFile(fileText, fileName, collectionId) {
async function rearrangeChat(chat) {
try {
// Clear the extension prompt
setExtensionPrompt(EXTENSION_PROMPT_TAG, '', extension_prompt_types.IN_PROMPT, 0, settings.include_wi);
setExtensionPrompt(EXTENSION_PROMPT_TAG, '', settings.position, settings.depth, settings.include_wi);
setExtensionPrompt(EXTENSION_PROMPT_TAG_DB, '', settings.file_position_db, settings.file_depth_db, settings.include_wi, settings.file_depth_role_db);
if (settings.enabled_files) {
await processFiles(chat);
@@ -563,6 +634,34 @@ async function queryCollection(collectionId, searchText, topK) {
return await response.json();
}
/**
* Queries multiple collections for a given text.
* @param {string[]} collectionIds - Collection IDs to query
* @param {string} searchText - Text to query
* @param {number} topK - Number of results to return
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - Results mapped to collection IDs
*/
async function queryMultipleCollections(collectionIds, searchText, topK) {
const headers = getVectorHeaders();
const response = await fetch('/api/vector/query-multi', {
method: 'POST',
headers: headers,
body: JSON.stringify({
collectionIds: collectionIds,
searchText: searchText,
topK: topK,
source: settings.source,
}),
});
if (!response.ok) {
throw new Error('Failed to query multiple collections');
}
return await response.json();
}
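For reference, a hedged usage sketch of the helper above (placeholder collection IDs, awaited inside an async context):

// Each key of the result maps a file collection to its best-matching chunks
const results = await queryMultipleCollections(['file_1234567890', 'file_0987654321'], 'query text', settings.chunk_count_db);
// results['file_1234567890']?.metadata -> [{ hash, index, text }, ...]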
/**
* Purges the vector index for a collection.
* @param {string} collectionId Collection ID to purge
@@ -761,6 +860,49 @@ jQuery(async () => {
saveSettingsDebounced();
});
$('#vectors_size_threshold_db').val(settings.size_threshold_db).on('input', () => {
settings.size_threshold_db = Number($('#vectors_size_threshold_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_chunk_size_db').val(settings.chunk_size_db).on('input', () => {
settings.chunk_size_db = Number($('#vectors_chunk_size_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_chunk_count_db').val(settings.chunk_count_db).on('input', () => {
settings.chunk_count_db = Number($('#vectors_chunk_count_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_file_template_db').val(settings.file_template_db).on('input', () => {
settings.file_template_db = String($('#vectors_file_template_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$(`input[name="vectors_file_position_db"][value="${settings.file_position_db}"]`).prop('checked', true);
$('input[name="vectors_file_position_db"]').on('change', () => {
settings.file_position_db = Number($('input[name="vectors_file_position_db"]:checked').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_file_depth_db').val(settings.file_depth_db).on('input', () => {
settings.file_depth_db = Number($('#vectors_file_depth_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_file_depth_role_db').val(settings.file_depth_role_db).on('input', () => {
settings.file_depth_role_db = Number($('#vectors_file_depth_role_db').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
const validSecret = !!secret_state[SECRET_KEYS.NOMICAI];
const placeholder = validSecret ? '✔️ Key saved' : '❌ Missing key';
$('#api_key_nomicai').attr('placeholder', placeholder);

View File

@@ -89,9 +89,11 @@
Enabled for files
</label>
<div id="vectors_files_settings">
<div class="flex-container">
<div id="vectors_files_settings" class="marginTopBot5">
<div class="flex justifyCenter" title="These settings apply to files attached directly to messages.">
<span>Message attachments</span>
</div>
<div class="flex-container marginBot5">
<div class="flex1" title="Only files past this size will be vectorized.">
<label for="vectors_size_threshold">
<small>Size threshold (KB)</small>
@@ -111,6 +113,58 @@
<input id="vectors_chunk_count" type="number" class="text_pole widthUnset" min="1" max="99999" />
</div>
</div>
<div class="flex justifyCenter" title="These settings apply to files stored in the Data Bank.">
<span>Data Bank files</span>
</div>
<div class="flex-container">
<div class="flex1" title="Only files past this size will be vectorized.">
<label for="vectors_size_threshold_db">
<small>Size threshold (KB)</small>
</label>
<input id="vectors_size_threshold_db" type="number" class="text_pole widthUnset" min="1" max="99999" />
</div>
<div class="flex1" title="Chunk size for file splitting.">
<label for="vectors_chunk_size_db">
<small>Chunk size (chars)</small>
</label>
<input id="vectors_chunk_size_db" type="number" class="text_pole widthUnset" min="1" max="99999" />
</div>
<div class="flex1" title="How many chunks to retrieve when querying.">
<label for="vectors_chunk_count_db">
<small>Retrieve chunks</small>
</label>
<input id="vectors_chunk_count_db" type="number" class="text_pole widthUnset" min="1" max="99999" />
</div>
</div>
<div class="flex-container flexFlowColumn">
<label for="vectors_file_template_db">
<span>Injection Template</span>
</label>
<textarea id="vectors_file_template_db" class="margin0 text_pole textarea_compact" rows="3" placeholder="Use &lcub;&lcub;text&rcub;&rcub; macro to specify the position of retrieved text."></textarea>
<label for="vectors_file_position_db">Injection Position</label>
<div class="radio_group">
<label>
<input type="radio" name="vectors_file_position_db" value="2" />
<span>Before Main Prompt / Story String</span>
</label>
<!--Keep these as 0 and 1 to interface with the setExtensionPrompt function-->
<label>
<input type="radio" name="vectors_file_position_db" value="0" />
<span>After Main Prompt / Story String</span>
</label>
<label for="vectors_file_depth_db" title="How many messages before the current end of the chat." data-i18n="[title]How many messages before the current end of the chat.">
<input type="radio" name="vectors_file_position_db" value="1" />
<span>In-chat @ Depth</span>
<input id="vectors_file_depth_db" class="text_pole widthUnset" type="number" min="0" max="999" />
<span>as</span>
<select id="vectors_file_depth_role_db" class="text_pole widthNatural">
<option value="0">System</option>
<option value="1">User</option>
<option value="2">Assistant</option>
</select>
</label>
</div>
</div>
</div>
<hr>
@@ -126,9 +180,9 @@
<div id="vectors_chats_settings">
<div id="vectors_advanced_settings">
<label for="vectors_template">
Insertion Template
Injection Template
</label>
<textarea id="vectors_template" class="text_pole textarea_compact" rows="3" placeholder="Use {{text}} macro to specify the position of retrieved text."></textarea>
<textarea id="vectors_template" class="text_pole textarea_compact" rows="3" placeholder="Use &lcub;&lcub;text&rcub;&rcub; macro to specify the position of retrieved text."></textarea>
<label for="vectors_position">Injection Position</label>
<div class="radio_group">
<label>

View File

@@ -998,6 +998,15 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
}
}
// Vectors Data Bank
if (prompts.has('vectorsDataBank')) {
const vectorsDataBank = prompts.get('vectorsDataBank');
if (vectorsDataBank.position) {
chatCompletion.insert(Message.fromPrompt(vectorsDataBank), 'main', vectorsDataBank.position);
}
}
// Smart Context (ChromaDB)
if (prompts.has('smartContext')) {
const smartContext = prompts.get('smartContext');
@@ -1089,6 +1098,14 @@ function preparePromptsForChatCompletion({ Scenario, charPersonality, name2, wor
position: getPromptPosition(vectorsMemory.position),
});
const vectorsDataBank = extensionPrompts['4_vectors_data_bank'];
if (vectorsDataBank && vectorsDataBank.value) systemPrompts.push({
role: getPromptRole(vectorsDataBank.role),
content: vectorsDataBank.value,
identifier: 'vectorsDataBank',
position: getPromptPosition(vectorsDataBank.position),
});
// Smart Context (ChromaDB)
const smartContext = extensionPrompts['chromadb'];
if (smartContext && smartContext.value) systemPrompts.push({

View File

@@ -685,6 +685,11 @@ export function sortMoments(a, b) {
* splitRecursive('Hello, world!', 3); // ['Hel', 'lo,', 'wor', 'ld!']
*/
export function splitRecursive(input, length, delimiters = ['\n\n', '\n', ' ', '']) {
// Invalid length
if (length <= 0) {
return [input];
}
const delim = delimiters[0] ?? '';
const parts = input.split(delim);
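A quick illustration of the new guard; the second call repeats the example already given in the JSDoc above:

splitRecursive('Hello, world!', -1); // ['Hello, world!'], a non-positive length keeps the input whole
splitRecursive('Hello, world!', 3);  // ['Hel', 'lo,', 'wor', 'ld!'], unchanged behaviour
// vectorizeFile relies on this when a small Data Bank file is passed chunkSize = -1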

View File

@@ -5,7 +5,7 @@ const sanitize = require('sanitize-filename');
const { jsonParser } = require('../express-common');
// Don't forget to add new sources to the SOURCES array
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togetherai', 'nomicai'];
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togetherai', 'nomicai', 'cohere'];
/**
* Gets the vector for the given text from the given source.
@@ -55,7 +55,7 @@ async function getBatchVector(source, sourceSettings, texts, directories) {
case 'togetherai':
case 'mistral':
case 'openai':
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, sourceSettings.model));
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, directories, sourceSettings.model));
break;
case 'transformers':
results.push(...await require('../embedding').getTransformersBatchVector(batch));
@@ -155,6 +155,7 @@ async function deleteVectorItems(directories, collectionId, source, hashes) {
/**
* Gets the hashes of the items in the vector collection that match the search text
* @param {import('../users').UserDirectoryList} directories - User directories
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {Object} sourceSettings - Settings for the source, if it needs any
@@ -172,6 +173,48 @@ async function queryCollection(directories, collectionId, source, sourceSettings
return { metadata, hashes };
}
/**
* Queries multiple collections for the given search text and returns the overall top K results.
* @param {import('../users').UserDirectoryList} directories - User directories
* @param {string[]} collectionIds - The collection IDs to query
* @param {string} source - The source of the vector
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {string} searchText - The text to search for
* @param {number} topK - The number of results to return
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - The overall top K results, grouped by collection ID
*/
async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK) {
const vector = await getVector(source, sourceSettings, searchText, directories);
const results = [];
for (const collectionId of collectionIds) {
const store = await getIndex(directories, collectionId, source);
const result = await store.queryItems(vector, topK);
results.push(...result.map(result => ({ collectionId, result })));
}
// Sort results by descending similarity
const sortedResults = results
.sort((a, b) => b.result.score - a.result.score)
.slice(0, topK);
/**
* Group the results by collection ID
* @type {Record<string, { hashes: number[], metadata: object[] }>}
*/
const groupedResults = {};
for (const result of sortedResults) {
if (!groupedResults[result.collectionId]) {
groupedResults[result.collectionId] = { hashes: [], metadata: [] };
}
groupedResults[result.collectionId].hashes.push(Number(result.result.item.metadata.hash));
groupedResults[result.collectionId].metadata.push(result.result.item.metadata);
}
return groupedResults;
}
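A self-contained sketch of the merge-and-group step above, with invented scores; note that topK caps the total number of results across all collections rather than per collection:

const sampleResults = [
    { collectionId: 'file_AAA', result: { score: 0.91, item: { metadata: { hash: 1, index: 0, text: 'A0' } } } },
    { collectionId: 'file_AAA', result: { score: 0.40, item: { metadata: { hash: 2, index: 1, text: 'A1' } } } },
    { collectionId: 'file_BBB', result: { score: 0.85, item: { metadata: { hash: 3, index: 0, text: 'B0' } } } },
];
const topK = 2;
const grouped = {};
for (const { collectionId, result } of sampleResults.sort((a, b) => b.result.score - a.result.score).slice(0, topK)) {
    if (!grouped[collectionId]) {
        grouped[collectionId] = { hashes: [], metadata: [] };
    }
    grouped[collectionId].hashes.push(Number(result.item.metadata.hash));
    grouped[collectionId].metadata.push(result.item.metadata);
}
// grouped -> { file_AAA: { hashes: [1], metadata: [{...}] }, file_BBB: { hashes: [3], metadata: [{...}] } }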
/**
* Extracts settings for the vectorization sources from the HTTP request headers.
* @param {string} source - Which source to extract settings for.
@@ -229,6 +272,26 @@ router.post('/query', jsonParser, async (req, res) => {
}
});
router.post('/query-multi', jsonParser, async (req, res) => {
try {
if (!Array.isArray(req.body.collectionIds) || !req.body.searchText) {
return res.sendStatus(400);
}
const collectionIds = req.body.collectionIds.map(x => String(x));
const searchText = String(req.body.searchText);
const topK = Number(req.body.topK) || 10;
const source = req.body.source ? String(req.body.source) : 'transformers';
const sourceSettings = getSourceSettings(source, req);
const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK);
return res.json(results);
} catch (error) {
console.error(error);
return res.sendStatus(500);
}
});
router.post('/insert', jsonParser, async (req, res) => {
try {
if (!Array.isArray(req.body.items) || !req.body.collectionId) {