Vector storage file retrieval

This commit is contained in:
Cohee 2023-11-30 00:01:59 +02:00
parent e0bf2b8e3e
commit 19df1f52cd
6 changed files with 276 additions and 90 deletions

View File

@ -9062,6 +9062,7 @@ jQuery(async function () {
hideStopButton();
}
eventSource.emit(event_types.GENERATION_STOPPED);
activateSendButtons();
});
$('.drawer-toggle').on('click', function () {

View File

@ -119,7 +119,7 @@ export async function populateFileAttachment(message, inputId = 'file_form_input
const fileText = await converter(file);
base64Data = window.btoa(unescape(encodeURIComponent(fileText)));
} catch (error) {
toastr.error(error, 'Could not convert file');
toastr.error(String(error), 'Could not convert file');
console.error('Could not convert file', error);
}
}
@ -169,7 +169,7 @@ export async function uploadFileAttachment(fileName, base64Data) {
const responseData = await result.json();
return responseData.path.replace(/\\/g, '/');
} catch (error) {
toastr.error(error, 'Could not upload file');
toastr.error(String(error), 'Could not upload file');
console.error('Could not upload file', error);
}
}

View File

@ -2,28 +2,37 @@ import { eventSource, event_types, extension_prompt_types, getCurrentChatId, get
import { ModuleWorkerWrapper, extension_settings, getContext, renderExtensionTemplate } from "../../extensions.js";
import { collapseNewlines, power_user, ui_mode } from "../../power-user.js";
import { SECRET_KEYS, secret_state } from "../../secrets.js";
import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique } from "../../utils.js";
import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique, splitRecursive } from "../../utils.js";
const MODULE_NAME = 'vectors';
export const EXTENSION_PROMPT_TAG = '3_vectors';
// Default extension settings. Stored values from extension_settings.vectors are
// merged over these defaults in the jQuery init handler.
const settings = {
// Legacy master switch; only consulted by the "Migrate from old settings" step.
enabled: false,
// For both
// Vectorization backend identifier (matches the #vectors_source option values).
source: 'transformers',
// For chats
// Turns chat message vectorization and retrieval on.
enabled_chats: false,
// Insertion template; the {{text}} macro marks where retrieved text goes.
template: `Past events: {{text}}`,
// In-chat depth, bound to the #vectors_depth input.
depth: 2,
// Injection position for the extension prompt.
position: extension_prompt_types.IN_PROMPT,
// Number of most recent messages kept in place (chat.slice(-protect)).
protect: 5,
// How many past messages to insert as memories.
insert: 3,
// How many last messages are matched for relevance.
query: 2,
// For files
// Turns file attachment vectorization and retrieval on.
enabled_files: false,
// Only files past this size (in kilobytes) are vectorized.
size_threshold: 5,
// Chunk size, in characters, used when splitting a file.
chunk_size: 1000,
// How many chunks to retrieve when querying a file collection.
chunk_count: 4,
};
const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
async function onVectorizeAllClick() {
try {
if (!settings.enabled) {
if (!settings.enabled_chats) {
return;
}
@ -78,7 +87,7 @@ async function onVectorizeAllClick() {
let syncBlocked = false;
async function synchronizeChat(batchSize = 5) {
if (!settings.enabled) {
if (!settings.enabled_chats) {
return -1;
}
@ -99,7 +108,7 @@ async function synchronizeChat(batchSize = 5) {
return -1;
}
const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(x.mes), hash: getStringHash(x.mes) }));
const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(x.mes), hash: getStringHash(x.mes), index: context.chat.indexOf(x) }));
const hashesInCollection = await getSavedHashes(chatId);
const newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash));
@ -149,6 +158,92 @@ function getStringHash(str) {
return hash;
}
/**
 * Replaces large file attachments embedded in chat messages with the chunks
 * of that file that are most relevant to the current query.
 *
 * For every message that carries a file whose text is at least
 * `settings.size_threshold` kilobytes long, the attachment text is cut out of
 * `message.mes`, vectorized on first use, and replaced by the retrieved chunks.
 * Messages are processed independently: if anything goes wrong for one of
 * them, its original text is restored and the remaining messages are still
 * handled.
 * @param {object[]} chat Array of chat messages (mutated in place)
 * @returns {Promise<void>}
 */
async function processFiles(chat) {
    if (!settings.enabled_files) {
        return;
    }

    // Convert kilobytes to string length (loop-invariant)
    const thresholdLength = settings.size_threshold * 1024;

    for (const message of chat) {
        // Message has no file
        if (!message?.extra?.file) {
            continue;
        }

        // Without a valid offset the attachment cannot be told apart from the
        // message text, so the whole message would be treated as the file.
        const fileStart = message.extra.fileStart;
        if (!Number.isInteger(fileStart) || fileStart < 0) {
            continue;
        }

        // Kept so the attachment can be put back if retrieval fails
        const originalMes = message.mes;

        try {
            // Trim file inserted by the script
            const fileText = originalMes.substring(fileStart).trim();

            // File is too small
            if (fileText.length < thresholdLength) {
                continue;
            }

            // Strip the attachment first so the query text built below
            // does not contain the whole file.
            message.mes = originalMes.substring(0, fileStart);

            const fileName = message.extra.file.name;
            const collectionId = `file_${getStringHash(fileName)}`;
            const hashesInCollection = await getSavedHashes(collectionId);

            // File is not in the collection yet
            if (!hashesInCollection.length) {
                await vectorizeFile(fileText, fileName, collectionId);
            }

            const queryText = getQueryText(chat);
            const fileChunks = await retrieveFileChunks(queryText, collectionId);

            // vectorizeFile logs and swallows its own errors, so an empty
            // result is the only sign that nothing could be retrieved.
            if (!fileChunks) {
                console.warn(`Vectors: No chunks retrieved for file ${fileName}, keeping the original text`);
                message.mes = originalMes;
                continue;
            }

            message.mes += '\n\n' + fileChunks;
        } catch (error) {
            // Never leave the message with its attachment silently removed
            message.mes = originalMes;
            console.error('Vectors: Failed to retrieve files', error);
        }
    }
}
/**
 * Queries a file collection and assembles the matching chunks into one string.
 * Chunks without text are dropped; the rest are ordered by their `index`
 * metadata and joined with newlines.
 * @param {string} queryText Text to query
 * @param {string} collectionId File collection ID
 * @returns {Promise<string>} Retrieved file text
 */
async function retrieveFileChunks(queryText, collectionId) {
    console.debug(`Vectors: Retrieving file chunks for collection ${collectionId}`, queryText);

    const topK = settings.chunk_count;
    const results = await queryCollection(collectionId, queryText, topK);
    console.debug(`Vectors: Retrieved ${results.hashes.length} file chunks for collection ${collectionId}`, results);

    // filter() yields a fresh array, so sorting it in place is safe
    const chunksWithText = results.metadata.filter((entry) => entry.text);
    chunksWithText.sort((left, right) => left.index - right.index);

    return chunksWithText.map((entry) => entry.text).join('\n');
}
/**
 * Splits a file into chunks and stores them in the given vector collection.
 * Each stored item carries the chunk hash, the chunk text and its position in
 * the file. Failures are logged to the console and not rethrown.
 * @param {string} fileText File text
 * @param {string} fileName File name
 * @param {string} collectionId File collection ID
 */
async function vectorizeFile(fileText, fileName, collectionId) {
    try {
        toastr.info("Vectorization may take some time, please wait...", `Ingesting file ${fileName}`);

        const pieces = splitRecursive(fileText, settings.chunk_size);
        console.debug(`Vectors: Split file ${fileName} into ${pieces.length} chunks`, pieces);

        // Remember each chunk's position so retrieval can restore file order
        const vectorItems = [];
        for (const [position, piece] of pieces.entries()) {
            vectorItems.push({ hash: getStringHash(piece), text: piece, index: position });
        }

        await insertVectorItems(collectionId, vectorItems);
        console.log(`Vectors: Inserted ${pieces.length} vector items for file ${fileName} into ${collectionId}`);
    } catch (failure) {
        console.error('Vectors: Failed to vectorize file', failure);
    }
}
/**
* Removes the most relevant messages from the chat and displays them in the extension prompt
* @param {object[]} chat Array of chat messages
@ -158,7 +253,11 @@ async function rearrangeChat(chat) {
// Clear the extension prompt
setExtensionPrompt(EXTENSION_PROMPT_TAG, '', extension_prompt_types.IN_PROMPT, 0);
if (!settings.enabled) {
if (settings.enabled_files) {
await processFiles(chat);
}
if (!settings.enabled_chats) {
return;
}
@ -182,7 +281,8 @@ async function rearrangeChat(chat) {
}
// Get the most relevant messages, excluding the last few
const queryHashes = (await queryCollection(chatId, queryText, settings.insert)).filter(onlyUnique);
const queryResults = await queryCollection(chatId, queryText, settings.query);
const queryHashes = queryResults.hashes.filter(onlyUnique);
const queriedMessages = [];
const insertedHashes = new Set();
const retainMessages = chat.slice(-settings.protect);
@ -335,7 +435,7 @@ async function deleteVectorItems(collectionId, hashes) {
* @param {string} collectionId - The collection to query
* @param {string} searchText - The text to query
* @param {number} topK - The number of results to return
* @returns {Promise<number[]>} - Hashes of the results
* @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results
*/
async function queryCollection(collectionId, searchText, topK) {
const response = await fetch('/api/vector/query', {
@ -359,7 +459,7 @@ async function queryCollection(collectionId, searchText, topK) {
async function purgeVectorIndex(collectionId) {
try {
if (!settings.enabled) {
if (!settings.enabled_chats) {
return;
}
@ -382,19 +482,36 @@ async function purgeVectorIndex(collectionId) {
}
}
/**
 * Shows or hides the file and chat settings panels according to their
 * respective "enabled" flags.
 */
function toggleSettings() {
    const filesPanel = $('#vectors_files_settings');
    filesPanel.toggle(Boolean(settings.enabled_files));

    const chatsPanel = $('#vectors_chats_settings');
    chatsPanel.toggle(Boolean(settings.enabled_chats));
}
jQuery(async () => {
if (!extension_settings.vectors) {
extension_settings.vectors = settings;
}
// Migrate from old settings
if (settings['enabled']) {
settings.enabled_chats = true;
}
Object.assign(settings, extension_settings.vectors);
// Migrate from TensorFlow to Transformers
settings.source = settings.source !== 'local' ? settings.source : 'transformers';
$('#extensions_settings2').append(renderExtensionTemplate(MODULE_NAME, 'settings'));
$('#vectors_enabled').prop('checked', settings.enabled).on('input', () => {
settings.enabled = $('#vectors_enabled').prop('checked');
$('#vectors_enabled_chats').prop('checked', settings.enabled_chats).on('input', () => {
settings.enabled_chats = $('#vectors_enabled_chats').prop('checked');
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
toggleSettings();
});
$('#vectors_enabled_files').prop('checked', settings.enabled_files).on('input', () => {
settings.enabled_files = $('#vectors_enabled_files').prop('checked');
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
toggleSettings();
});
$('#vectors_source').val(settings.source).on('change', () => {
settings.source = String($('#vectors_source').val());
@ -436,6 +553,25 @@ jQuery(async () => {
$('#vectors_vectorize_all').on('click', onVectorizeAllClick);
$('#vectors_size_threshold').val(settings.size_threshold).on('input', () => {
settings.size_threshold = Number($('#vectors_size_threshold').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_chunk_size').val(settings.chunk_size).on('input', () => {
settings.chunk_size = Number($('#vectors_chunk_size').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_chunk_count').val(settings.chunk_count).on('input', () => {
settings.chunk_count = Number($('#vectors_chunk_count').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
toggleSettings();
eventSource.on(event_types.MESSAGE_DELETED, onChatEvent);
eventSource.on(event_types.MESSAGE_EDITED, onChatEvent);
eventSource.on(event_types.MESSAGE_SENT, onChatEvent);

View File

@ -5,72 +5,119 @@
<div class="inline-drawer-icon fa-solid fa-circle-chevron-down down"></div>
</div>
<div class="inline-drawer-content">
<label class="checkbox_label" for="vectors_enabled">
<input id="vectors_enabled" type="checkbox" class="checkbox">
Enabled
</label>
<label for="vectors_source">
Vectorization Source
</label>
<select id="vectors_source" class="select">
<option value="transformers">Local (Transformers)</option>
<option value="openai">OpenAI</option>
<option value="palm">Google MakerSuite (PaLM)</option>
</select>
<div id="vectors_advanced_settings" data-newbie-hidden>
<label for="vectors_template">
Insertion Template
<div class="flex-container flexFlowColumn">
<label for="vectors_source">
Vectorization Source
</label>
<textarea id="vectors_template" class="text_pole textarea_compact autoSetHeight" rows="2" placeholder="Use {{text}} macro to specify the position of retrieved text."></textarea>
<label for="vectors_position">Injection Position</label>
<div class="radio_group">
<label>
<input type="radio" name="vectors_position" value="2" />
Before Main Prompt / Story String
</label>
<!--Keep these as 0 and 1 to interface with the setExtensionPrompt function-->
<label>
<input type="radio" name="vectors_position" value="0" />
After Main Prompt / Story String
</label>
<label>
<input type="radio" name="vectors_position" value="1" />
In-chat @ Depth <input id="vectors_depth" class="text_pole widthUnset" type="number" min="0" max="999" />
</label>
</div>
<select id="vectors_source" class="text_pole">
<option value="transformers">Local (Transformers)</option>
<option value="openai">OpenAI</option>
<option value="palm">Google MakerSuite (PaLM)</option>
</select>
</div>
<div class="flex-container flexFlowColumn" title="How many last messages will be matched for relevance.">
<label for="vectors_query">
<span>Query messages</span>
</label>
<input type="number" id="vectors_query" class="text_pole widthUnset" min="1" max="99" />
</div>
<hr>
<h4>
File vectorization settings
</h4>
<label class="checkbox_label" for="vectors_enabled_files">
<input id="vectors_enabled_files" type="checkbox" class="checkbox">
Enabled for files
</label>
<div id="vectors_files_settings">
<div class="flex-container">
<div class="flex1" title="Prevents last N messages from being placed out of order.">
<label for="vectors_protect">
<small>Retain#</small>
<div class="flex1" title="Only files past this size will be vectorized.">
<label for="vectors_size_threshold">
<small>Size threshold (KB)</small>
</label>
<input type="number" id="vectors_protect" class="text_pole widthUnset" min="1" max="99" />
<input id="vectors_size_threshold" type="number" class="text_pole widthUnset" min="1" max="99999" />
</div>
<div class="flex1" title="How many last messages will be matched for relevance.">
<label for="vectors_query">
<small>Query#</small>
<div class="flex1" title="Chunk size for file splitting.">
<label for="vectors_chunk_size">
<small>Chunk size (chars)</small>
</label>
<input type="number" id="vectors_query" class="text_pole widthUnset" min="1" max="99" />
<input id="vectors_chunk_size" type="number" class="text_pole widthUnset" min="1" max="99999" />
</div>
<div class="flex1" title="How many past messages to insert as memories.">
<label for="vectors_insert">
<small>Insert#</small>
<div class="flex1" title="How many chunks to retrieve when querying.">
<label for="vectors_chunk_count">
<small>Retrieve chunks</small>
</label>
<input type="number" id="vectors_insert" class="text_pole widthUnset" min="1" max="99" />
<input id="vectors_chunk_count" type="number" class="text_pole widthUnset" min="1" max="99999" />
</div>
</div>
</div>
<small>
Old messages are vectorized gradually as you chat.
To process all previous messages, click the button below.
</small>
<div id="vectors_vectorize_all" class="menu_button menu_button_icon">
Vectorize All
</div>
<div id="vectorize_progress" style="display: none;">
<hr>
<h4>
Chat vectorization settings
</h4>
<label class="checkbox_label" for="vectors_enabled_chats">
<input id="vectors_enabled_chats" type="checkbox" class="checkbox">
Enabled for chat messages
</label>
<div id="vectors_chats_settings">
<div id="vectors_advanced_settings" data-newbie-hidden>
<label for="vectors_template">
Insertion Template
</label>
<textarea id="vectors_template" class="text_pole textarea_compact" rows="3" placeholder="Use {{text}} macro to specify the position of retrieved text."></textarea>
<label for="vectors_position">Injection Position</label>
<div class="radio_group">
<label>
<input type="radio" name="vectors_position" value="2" />
Before Main Prompt / Story String
</label>
<!--Keep these as 0 and 1 to interface with the setExtensionPrompt function-->
<label>
<input type="radio" name="vectors_position" value="0" />
After Main Prompt / Story String
</label>
<label>
<input type="radio" name="vectors_position" value="1" />
In-chat @ Depth <input id="vectors_depth" class="text_pole widthUnset" type="number" min="0" max="999" />
</label>
</div>
<div class="flex-container">
<div class="flex1" title="Prevents last N messages from being placed out of order.">
<label for="vectors_protect">
<small>Retain#</small>
</label>
<input type="number" id="vectors_protect" class="text_pole widthUnset" min="1" max="99" />
</div>
<div class="flex1" title="How many past messages to insert as memories.">
<label for="vectors_insert">
<small>Insert#</small>
</label>
<input type="number" id="vectors_insert" class="text_pole widthUnset" min="1" max="99" />
</div>
</div>
</div>
<small>
Processed <span id="vectorize_progress_percent">0</span>% of messages.
ETA: <span id="vectorize_progress_eta">...</span> seconds.
Old messages are vectorized gradually as you chat.
To process all previous messages, click the button below.
</small>
<div id="vectors_vectorize_all" class="menu_button menu_button_icon">
Vectorize All
</div>
<div id="vectorize_progress" style="display: none;">
<small>
Processed <span id="vectorize_progress_percent">0</span>% of messages.
ETA: <span id="vectorize_progress_eta">...</span> seconds.
</small>
</div>
</div>
</div>
</div>

View File

@ -213,8 +213,8 @@ var chatsPath = 'public/chats/';
const SETTINGS_FILE = './public/settings.json';
const AVATAR_WIDTH = 400;
const AVATAR_HEIGHT = 600;
const jsonParser = express.json({ limit: '100mb' });
const urlencodedParser = express.urlencoded({ extended: true, limit: '100mb' });
const jsonParser = express.json({ limit: '200mb' });
const urlencodedParser = express.urlencoded({ extended: true, limit: '200mb' });
const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY } = require('./src/constants');
const { TavernCardValidator } = require("./src/validator/TavernCardValidator");

View File

@ -30,34 +30,35 @@ async function getVector(source, text) {
* @returns {Promise<vectra.LocalIndex>} - The index for the collection
*/
async function getIndex(collectionId, source, create = true) {
    // One on-disk index per (source, collection) pair; both path segments are
    // sanitized before being joined under <cwd>/vectors.
    const store = new vectra.LocalIndex(path.join(process.cwd(), 'vectors', sanitize(source), sanitize(collectionId)));

    // Create the index on first use unless the caller opted out
    if (create && !await store.isIndexCreated()) {
        await store.createIndex();
    }

    return store;
}
/**
 * Inserts items into the vector collection
 * @param {string} collectionId - The collection ID
 * @param {string} source - The source of the vector
 * @param {{ hash: number; text: string; index: number; }[]} items - The items to insert
 */
async function insertVectorItems(collectionId, source, items) {
    // Named `store` so it cannot be confused with an item's `index` ordinal
    const store = await getIndex(collectionId, source);

    await store.beginUpdate();

    for (const item of items) {
        const { hash, text, index } = item;
        const vector = await getVector(source, text);
        // `index` is kept in metadata alongside hash and text
        await store.upsertItem({ vector: vector, metadata: { hash, text, index } });
    }

    await store.endUpdate();
}
/**
@ -67,9 +68,9 @@ async function insertVectorItems(collectionId, source, items) {
* @returns {Promise<number[]>} - The hashes of the items in the collection
*/
async function getSavedHashes(collectionId, source) {
const index = await getIndex(collectionId, source);
const store = await getIndex(collectionId, source);
const items = await index.listItems();
const items = await store.listItems();
const hashes = items.map(x => Number(x.metadata.hash));
return hashes;
@ -82,16 +83,16 @@ async function getSavedHashes(collectionId, source) {
* @param {number[]} hashes - The hashes of the items to delete
*/
async function deleteVectorItems(collectionId, source, hashes) {
    const store = await getIndex(collectionId, source);
    // Resolve the stored items whose metadata hash is one of the requested hashes
    const items = await store.listItemsByMetadata({ hash: { '$in': hashes } });

    await store.beginUpdate();

    for (const item of items) {
        await store.deleteItem(item.id);
    }

    await store.endUpdate();
}
/**
@ -100,15 +101,16 @@ async function deleteVectorItems(collectionId, source, hashes) {
* @param {string} source - The source of the vector
* @param {string} searchText - The text to search for
* @param {number} topK - The number of results to return
* @returns {Promise<number[]>} - The hashes of the items that match the search text
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
*/
async function queryCollection(collectionId, source, searchText, topK) {
    const store = await getIndex(collectionId, source);
    const vector = await getVector(source, searchText);

    const result = await store.queryItems(vector, topK);
    // Return the full metadata of each match next to the hashes,
    // which are coerced to numbers.
    const metadata = result.map(x => x.item.metadata);
    const hashes = result.map(x => Number(x.item.metadata.hash));
    return { metadata, hashes };
}
/**
@ -143,7 +145,7 @@ async function registerEndpoints(app, jsonParser) {
}
const collectionId = String(req.body.collectionId);
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text }));
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text, index: x.index }));
const source = String(req.body.source) || 'transformers';
await insertVectorItems(collectionId, source, items);