diff --git a/default/config.yaml b/default/config.yaml
index 5f3e0ce9a..694a58be6 100644
--- a/default/config.yaml
+++ b/default/config.yaml
@@ -108,6 +108,9 @@ enableExtensionsAutoUpdate: true
 # Additional model tokenizers can be downloaded on demand.
 # Disabling will fall back to another locally available tokenizer.
 enableDownloadableTokenizers: true
+# Vector storage settings
+vectors:
+  enableModelScopes: false
 # Extension settings
 extras:
   # Disables automatic model download from HuggingFace
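The new flag ships disabled, so existing installs keep their current on-disk index layout until they opt in. For reference, a minimal sketch (not part of the patch) of how server code reads this value; it mirrors the getConfigValue import and call added to src/endpoints/vectors.js later in this diff:

    const { getConfigValue } = require('../util');

    // Defaults to false until model scopes become opt-out.
    const scopesEnabled = getConfigValue('vectors.enableModelScopes', false);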
diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js
index fb97a3d1a..1b815612e 100644
--- a/public/scripts/extensions/vectors/index.js
+++ b/public/scripts/extensions/vectors/index.js
@@ -718,7 +718,7 @@ async function getQueryText(chat, initiator) {
 async function getSavedHashes(collectionId) {
     const response = await fetch('/api/vector/list', {
         method: 'POST',
-        headers: getRequestHeaders(),
+        headers: getVectorHeaders(),
         body: JSON.stringify({
             collectionId: collectionId,
             source: settings.source,
@@ -737,25 +737,43 @@ function getVectorHeaders() {
     const headers = getRequestHeaders();
     switch (settings.source) {
         case 'extras':
-            addExtrasHeaders(headers);
+            Object.assign(headers, {
+                'X-Extras-Url': extension_settings.apiUrl,
+                'X-Extras-Key': extension_settings.apiKey,
+            });
             break;
         case 'togetherai':
-            addTogetherAiHeaders(headers);
+            Object.assign(headers, {
+                'X-Togetherai-Model': extension_settings.vectors.togetherai_model,
+            });
             break;
         case 'openai':
-            addOpenAiHeaders(headers);
+            Object.assign(headers, {
+                'X-OpenAI-Model': extension_settings.vectors.openai_model,
+            });
             break;
         case 'cohere':
-            addCohereHeaders(headers);
+            Object.assign(headers, {
+                'X-Cohere-Model': extension_settings.vectors.cohere_model,
+            });
             break;
         case 'ollama':
-            addOllamaHeaders(headers);
+            Object.assign(headers, {
+                'X-Ollama-Model': extension_settings.vectors.ollama_model,
+                'X-Ollama-URL': textgenerationwebui_settings.server_urls[textgen_types.OLLAMA],
+                'X-Ollama-Keep': !!extension_settings.vectors.ollama_keep,
+            });
             break;
         case 'llamacpp':
-            addLlamaCppHeaders(headers);
+            Object.assign(headers, {
+                'X-LlamaCpp-URL': textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP],
+            });
             break;
         case 'vllm':
-            addVllmHeaders(headers);
+            Object.assign(headers, {
+                'X-Vllm-URL': textgenerationwebui_settings.server_urls[textgen_types.VLLM],
+                'X-Vllm-Model': extension_settings.vectors.vllm_model,
+            });
             break;
         default:
             break;
@@ -763,81 +781,6 @@ function getVectorHeaders() {
     return headers;
 }
 
-/**
- * Add headers for the Extras API source.
- * @param {object} headers Headers object
- */
-function addExtrasHeaders(headers) {
-    console.log(`Vector source is extras, populating API URL: ${extension_settings.apiUrl}`);
-    Object.assign(headers, {
-        'X-Extras-Url': extension_settings.apiUrl,
-        'X-Extras-Key': extension_settings.apiKey,
-    });
-}
-
-/**
- * Add headers for the TogetherAI API source.
- * @param {object} headers Headers object
- */
-function addTogetherAiHeaders(headers) {
-    Object.assign(headers, {
-        'X-Togetherai-Model': extension_settings.vectors.togetherai_model,
-    });
-}
-
-/**
- * Add headers for the OpenAI API source.
- * @param {object} headers Header object
- */
-function addOpenAiHeaders(headers) {
-    Object.assign(headers, {
-        'X-OpenAI-Model': extension_settings.vectors.openai_model,
-    });
-}
-
-/**
- * Add headers for the Cohere API source.
- * @param {object} headers Header object
- */
-function addCohereHeaders(headers) {
-    Object.assign(headers, {
-        'X-Cohere-Model': extension_settings.vectors.cohere_model,
-    });
-}
-
-/**
- * Add headers for the Ollama API source.
- * @param {object} headers Header object
- */
-function addOllamaHeaders(headers) {
-    Object.assign(headers, {
-        'X-Ollama-Model': extension_settings.vectors.ollama_model,
-        'X-Ollama-URL': textgenerationwebui_settings.server_urls[textgen_types.OLLAMA],
-        'X-Ollama-Keep': !!extension_settings.vectors.ollama_keep,
-    });
-}
-
-/**
- * Add headers for the LlamaCpp API source.
- * @param {object} headers Header object
- */
-function addLlamaCppHeaders(headers) {
-    Object.assign(headers, {
-        'X-LlamaCpp-URL': textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP],
-    });
-}
-
-/**
- * Add headers for the VLLM API source.
- * @param {object} headers Header object
- */
-function addVllmHeaders(headers) {
-    Object.assign(headers, {
-        'X-Vllm-URL': textgenerationwebui_settings.server_urls[textgen_types.VLLM],
-        'X-Vllm-Model': extension_settings.vectors.vllm_model,
-    });
-}
-
 /**
  * Inserts vector items into a collection
  * @param {string} collectionId - The collection to insert into
@@ -901,7 +844,7 @@ function throwIfSourceInvalid() {
 async function deleteVectorItems(collectionId, hashes) {
     const response = await fetch('/api/vector/delete', {
         method: 'POST',
-        headers: getRequestHeaders(),
+        headers: getVectorHeaders(),
         body: JSON.stringify({
             collectionId: collectionId,
             hashes: hashes,
@@ -987,7 +930,7 @@ async function purgeFileVectorIndex(fileUrl) {
 
         const response = await fetch('/api/vector/purge', {
             method: 'POST',
-            headers: getRequestHeaders(),
+            headers: getVectorHeaders(),
             body: JSON.stringify({
                 collectionId: collectionId,
             }),
@@ -1016,7 +959,7 @@ async function purgeVectorIndex(collectionId) {
 
         const response = await fetch('/api/vector/purge', {
             method: 'POST',
-            headers: getRequestHeaders(),
+            headers: getVectorHeaders(),
             body: JSON.stringify({
                 collectionId: collectionId,
             }),
@@ -1041,7 +984,7 @@ async function purgeAllVectorIndexes() {
     try {
         const response = await fetch('/api/vector/purge-all', {
             method: 'POST',
-            headers: getRequestHeaders(),
+            headers: getVectorHeaders(),
         });
 
         if (!response.ok) {
@@ -1056,6 +999,25 @@ async function purgeAllVectorIndexes() {
     }
 }
 
+async function isModelScopesEnabled() {
+    try {
+        const response = await fetch('/api/vector/scopes-enabled', {
+            method: 'GET',
+            headers: getVectorHeaders(),
+        });
+
+        if (!response.ok) {
+            return false;
+        }
+
+        const data = await response.json();
+        return data?.enabled ?? false;
+    } catch (error) {
+        console.error('Vectors: Failed to check model scopes', error);
+        return false;
+    }
+}
+
 function toggleSettings() {
     $('#vectors_files_settings').toggle(!!settings.enabled_files);
     $('#vectors_chats_settings').toggle(!!settings.enabled_chats);
@@ -1320,6 +1282,7 @@ jQuery(async () => {
     }
 
     Object.assign(settings, extension_settings.vectors);
+    const scopesEnabled = await isModelScopesEnabled();
 
     // Migrate from TensorFlow to Transformers
     settings.source = settings.source !== 'local' ? settings.source : 'transformers';
@@ -1371,31 +1334,31 @@ jQuery(async () => {
         saveSettingsDebounced();
     });
     $('#vectors_togetherai_model').val(settings.togetherai_model).on('change', () => {
-        $('#vectors_modelWarning').show();
+        !scopesEnabled && $('#vectors_modelWarning').show();
         settings.togetherai_model = String($('#vectors_togetherai_model').val());
         Object.assign(extension_settings.vectors, settings);
         saveSettingsDebounced();
     });
     $('#vectors_openai_model').val(settings.openai_model).on('change', () => {
-        $('#vectors_modelWarning').show();
+        !scopesEnabled && $('#vectors_modelWarning').show();
         settings.openai_model = String($('#vectors_openai_model').val());
         Object.assign(extension_settings.vectors, settings);
         saveSettingsDebounced();
     });
     $('#vectors_cohere_model').val(settings.cohere_model).on('change', () => {
-        $('#vectors_modelWarning').show();
+        !scopesEnabled && $('#vectors_modelWarning').show();
        settings.cohere_model = String($('#vectors_cohere_model').val());
         Object.assign(extension_settings.vectors, settings);
         saveSettingsDebounced();
     });
     $('#vectors_ollama_model').val(settings.ollama_model).on('input', () => {
-        $('#vectors_modelWarning').show();
+        !scopesEnabled && $('#vectors_modelWarning').show();
         settings.ollama_model = String($('#vectors_ollama_model').val());
         Object.assign(extension_settings.vectors, settings);
         saveSettingsDebounced();
     });
     $('#vectors_vllm_model').val(settings.vllm_model).on('input', () => {
-        $('#vectors_modelWarning').show();
+        !scopesEnabled && $('#vectors_modelWarning').show();
         settings.vllm_model = String($('#vectors_vllm_model').val());
         Object.assign(extension_settings.vectors, settings);
         saveSettingsDebounced();
     });
diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html
index 0dc626e53..f1e73016e 100644
--- a/public/scripts/extensions/vectors/settings.html
+++ b/public/scripts/extensions/vectors/settings.html
@@ -98,8 +98,9 @@
-
-    It is recommended to purge vectors when changing the model mid-chat. Otherwise, it will lead to sub-par results.
+
+    Set vectors.enableModelScopes to true in config.yaml to switch between vectorization models without needing to purge existing vectors.
+    This option will soon be enabled by default.
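Taken together, the client changes gate the stale-vector warning on the server's capability: isModelScopesEnabled() queries the new /api/vector/scopes-enabled endpoint once at startup, and each model field handler only reveals #vectors_modelWarning when scopes are off. A condensed sketch of the pattern (jQuery, matching the handlers above; only the openai field is shown):

    // Inside the jQuery(async () => { ... }) entry point:
    const scopesEnabled = await isModelScopesEnabled();

    $('#vectors_openai_model').on('change', () => {
        // With model scopes active, each model gets its own index,
        // so switching models needs no purge warning.
        !scopesEnabled && $('#vectors_modelWarning').show();
    });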
diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js
index 38f74f7d8..790b5693d 100644
--- a/src/endpoints/vectors.js
+++ b/src/endpoints/vectors.js
@@ -4,6 +4,7 @@ const fs = require('fs');
 const express = require('express');
 const sanitize = require('sanitize-filename');
 const { jsonParser } = require('../express-common');
+const { getConfigValue, color } = require('../util');
 
 // Don't forget to add new sources to the SOURCES array
 const SOURCES = [
@@ -109,19 +110,95 @@ async function getBatchVector(source, sourceSettings, texts, isQuery, directories) {
     return results;
 }
 
+/**
+ * Extracts settings for the vectorization sources from the HTTP request headers.
+ * @param {string} source - Which source to extract settings for.
+ * @param {object} request - The HTTP request object.
+ * @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
+ */
+function getSourceSettings(source, request) {
+    switch (source) {
+        case 'togetherai':
+            return {
+                model: String(request.headers['x-togetherai-model']),
+            };
+        case 'openai':
+            return {
+                model: String(request.headers['x-openai-model']),
+            };
+        case 'cohere':
+            return {
+                model: String(request.headers['x-cohere-model']),
+            };
+        case 'llamacpp':
+            return {
+                apiUrl: String(request.headers['x-llamacpp-url']),
+            };
+        case 'vllm':
+            return {
+                apiUrl: String(request.headers['x-vllm-url']),
+                model: String(request.headers['x-vllm-model']),
+            };
+        case 'ollama':
+            return {
+                apiUrl: String(request.headers['x-ollama-url']),
+                model: String(request.headers['x-ollama-model']),
+                keep: Boolean(request.headers['x-ollama-keep']),
+            };
+        case 'extras':
+            return {
+                extrasUrl: String(request.headers['x-extras-url']),
+                extrasKey: String(request.headers['x-extras-key']),
+            };
+        case 'local':
+            return {
+                model: getConfigValue('extras.embeddingModel', ''),
+            };
+        case 'palm':
+            return {
+                // TODO: Add support for multiple models
+                model: 'text-embedding-004',
+            };
+        default:
+            return {};
+    }
+}
+
+/**
+ * Gets the model scope for the source.
+ * @param {object} sourceSettings - The settings for the source
+ * @returns {string} The model scope for the source
+ */
+function getModelScope(sourceSettings) {
+    const scopesEnabled = getConfigValue('vectors.enableModelScopes', false);
+    const warningShown = global.process.env.VECTORS_MODEL_SCOPE_WARNING_SHOWN === 'true';
+
+    if (!scopesEnabled && !warningShown) {
+        console.log();
+        console.warn(color.red('[DEPRECATION NOTICE]'), 'Model scopes for Vector Storage are disabled, but will soon be required.');
+        console.log(`To enable model scopes, set the ${color.cyan('vectors.enableModelScopes')} option in config.yaml to ${color.green(true)}.`);
+        console.log('This message won\'t be shown again in the current session.');
+        console.log();
+        global.process.env.VECTORS_MODEL_SCOPE_WARNING_SHOWN = 'true';
+    }
+
+    return scopesEnabled ? (sourceSettings?.model || '') : '';
+}
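The scope string feeds directly into the index path below, which is what keeps the change backward compatible: with scopes disabled, getModelScope() returns '', and Node's path.join() drops zero-length segments, so legacy collections resolve to the same directory as before. A small illustration (plain Node; the source, collection, and model names are hypothetical):

    const path = require('path');

    // Scopes off: the empty model segment vanishes, preserving the old layout.
    path.join('vectors', 'openai', 'chat123', '');
    // -> 'vectors/openai/chat123'

    // Scopes on: each embedding model gets its own subdirectory.
    path.join('vectors', 'openai', 'chat123', 'text-embedding-3-small');
    // -> 'vectors/openai/chat123/text-embedding-3-small'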
 /**
  * Gets the index for the vector collection
  * @param {import('../users').UserDirectoryList} directories - User directories
  * @param {string} collectionId - The collection ID
  * @param {string} source - The source of the vector
- * @param {boolean} create - Whether to create the index if it doesn't exist
+ * @param {object} sourceSettings - Settings for the source
  * @returns {Promise<vectra.LocalIndex>} - The index for the collection
  */
-async function getIndex(directories, collectionId, source, create = true) {
-    const pathToFile = path.join(directories.vectors, sanitize(source), sanitize(collectionId));
+async function getIndex(directories, collectionId, source, sourceSettings) {
+    const model = getModelScope(sourceSettings);
+    const pathToFile = path.join(directories.vectors, sanitize(source), sanitize(collectionId), sanitize(model));
     const store = new vectra.LocalIndex(pathToFile);
 
-    if (create && !await store.isIndexCreated()) {
+    if (!await store.isIndexCreated()) {
         await store.createIndex();
     }
 
     return store;
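The vectra calls themselves are unchanged; only the path now varies by scope. A minimal sketch of the store lifecycle as getIndex() performs it (LocalIndex, isIndexCreated, and createIndex are the same APIs used above; the wrapper name is hypothetical):

    const vectra = require('vectra');

    async function openIndex(pathToFile) {
        const store = new vectra.LocalIndex(pathToFile);
        // The removed create=false probe no longer exists: the index
        // is always created on first access.
        if (!await store.isIndexCreated()) {
            await store.createIndex();
        }
        return store;
    }

One side effect worth flagging: regenerateCorruptedIndexErrorHandler() further down still checks isIndexCreated() after calling getIndex(), and that check can now only ever see a created index.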
@@ -137,7 +214,7 @@
  * @param {{ hash: number; text: string; index: number; }[]} items - The items to insert
  */
 async function insertVectorItems(directories, collectionId, source, sourceSettings, items) {
-    const store = await getIndex(directories, collectionId, source);
+    const store = await getIndex(directories, collectionId, source, sourceSettings);
 
     await store.beginUpdate();
@@ -157,10 +234,11 @@
  * @param {import('../users').UserDirectoryList} directories - User directories
  * @param {string} collectionId - The collection ID
  * @param {string} source - The source of the vector
+ * @param {object} sourceSettings - Settings for the source, if it needs any
  * @returns {Promise<number[]>} - The hashes of the items in the collection
  */
-async function getSavedHashes(directories, collectionId, source) {
-    const store = await getIndex(directories, collectionId, source);
+async function getSavedHashes(directories, collectionId, source, sourceSettings) {
+    const store = await getIndex(directories, collectionId, source, sourceSettings);
 
     const items = await store.listItems();
     const hashes = items.map(x => Number(x.metadata.hash));
@@ -173,10 +251,11 @@
  * @param {import('../users').UserDirectoryList} directories - User directories
  * @param {string} collectionId - The collection ID
  * @param {string} source - The source of the vector
+ * @param {object} sourceSettings - Settings for the source, if it needs any
  * @param {number[]} hashes - The hashes of the items to delete
  */
-async function deleteVectorItems(directories, collectionId, source, hashes) {
-    const store = await getIndex(directories, collectionId, source);
+async function deleteVectorItems(directories, collectionId, source, sourceSettings, hashes) {
+    const store = await getIndex(directories, collectionId, source, sourceSettings);
     const items = await store.listItemsByMetadata({ hash: { '$in': hashes } });
 
     await store.beginUpdate();
@@ -200,7 +279,7 @@
  * @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
  */
 async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK, threshold) {
-    const store = await getIndex(directories, collectionId, source);
+    const store = await getIndex(directories, collectionId, source, sourceSettings);
 
     const vector = await getVector(source, sourceSettings, searchText, true, directories);
     const result = await store.queryItems(vector, topK);
@@ -226,7 +305,7 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK, threshold) {
     const results = [];
 
     for (const collectionId of collectionIds) {
-        const store = await getIndex(directories, collectionId, source);
+        const store = await getIndex(directories, collectionId, source, sourceSettings);
         const result = await store.queryItems(vector, topK);
         results.push(...result.map(result => ({ collectionId, result })));
     }
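Below, the old if/else getSourceSettings() is deleted; it moved above getIndex() (rewritten as a switch with new local and palm cases) so the model scope is available wherever an index is opened. One caveat the move carries over: HTTP header values arrive as strings, so Boolean(request.headers['x-ollama-keep']) is true for both 'true' and 'false'. A stricter parse, if that ever matters, could look like this (sketch, not part of the patch):

    // 'false' is a non-empty string, so Boolean('false') === true.
    const keep = String(request.headers['x-ollama-keep']) === 'true';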
@@ -254,71 +333,6 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK, threshold) {
     return groupedResults;
 }
 
-/**
- * Extracts settings for the vectorization sources from the HTTP request headers.
- * @param {string} source - Which source to extract settings for.
- * @param {object} request - The HTTP request object.
- * @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
- */
-function getSourceSettings(source, request) {
-    if (source === 'togetherai') {
-        const model = String(request.headers['x-togetherai-model']);
-
-        return {
-            model: model,
-        };
-    } else if (source === 'openai') {
-        const model = String(request.headers['x-openai-model']);
-
-        return {
-            model: model,
-        };
-    } else if (source === 'cohere') {
-        const model = String(request.headers['x-cohere-model']);
-
-        return {
-            model: model,
-        };
-    } else if (source === 'llamacpp') {
-        const apiUrl = String(request.headers['x-llamacpp-url']);
-
-        return {
-            apiUrl: apiUrl,
-        };
-    } else if (source === 'vllm') {
-        const apiUrl = String(request.headers['x-vllm-url']);
-        const model = String(request.headers['x-vllm-model']);
-
-        return {
-            apiUrl: apiUrl,
-            model: model,
-        };
-    } else if (source === 'ollama') {
-        const apiUrl = String(request.headers['x-ollama-url']);
-        const model = String(request.headers['x-ollama-model']);
-        const keep = Boolean(request.headers['x-ollama-keep']);
-
-        return {
-            apiUrl: apiUrl,
-            model: model,
-            keep: keep,
-        };
-    } else {
-        // Extras API settings to connect to the Extras embeddings provider
-        let extrasUrl = '';
-        let extrasKey = '';
-        if (source === 'extras') {
-            extrasUrl = String(request.headers['x-extras-url']);
-            extrasKey = String(request.headers['x-extras-key']);
-        }
-
-        return {
-            extrasUrl: extrasUrl,
-            extrasKey: extrasKey,
-        };
-    }
-}
-
 /**
  * Performs a request to regenerate the index if it is corrupted.
  * @param {import('express').Request} req Express request object
@@ -330,9 +344,10 @@ async function regenerateCorruptedIndexErrorHandler(req, res, error) {
     if (error instanceof SyntaxError && !req.query.regenerated) {
         const collectionId = String(req.body.collectionId);
         const source = String(req.body.source) || 'transformers';
+        const sourceSettings = getSourceSettings(source, req);
 
         if (collectionId && source) {
-            const index = await getIndex(req.user.directories, collectionId, source, false);
+            const index = await getIndex(req.user.directories, collectionId, source, sourceSettings);
             const exists = await index.isIndexCreated();
 
             if (exists) {
@@ -350,6 +365,11 @@ async function regenerateCorruptedIndexErrorHandler(req, res, error) {
 
 const router = express.Router();
 
+router.get('/scopes-enabled', (_req, res) => {
+    const scopesEnabled = getConfigValue('vectors.enableModelScopes', false);
+    return res.json({ enabled: scopesEnabled });
+});
+
 router.post('/query', jsonParser, async (req, res) => {
     try {
         if (!req.body.collectionId || !req.body.searchText) {
@@ -416,8 +436,9 @@ router.post('/list', jsonParser, async (req, res) => {
 
         const collectionId = String(req.body.collectionId);
         const source = String(req.body.source) || 'transformers';
+        const sourceSettings = getSourceSettings(source, req);
 
-        const hashes = await getSavedHashes(req.user.directories, collectionId, source);
+        const hashes = await getSavedHashes(req.user.directories, collectionId, source, sourceSettings);
         return res.json(hashes);
     } catch (error) {
         return regenerateCorruptedIndexErrorHandler(req, res, error);
@@ -433,8 +454,9 @@ router.post('/delete', jsonParser, async (req, res) => {
 
         const collectionId = String(req.body.collectionId);
         const hashes = req.body.hashes.map(x => Number(x));
         const source = String(req.body.source) || 'transformers';
+        const sourceSettings = getSourceSettings(source, req);
 
-        await deleteVectorItems(req.user.directories, collectionId, source, hashes);
+        await deleteVectorItems(req.user.directories, collectionId, source, sourceSettings, hashes);
         return res.sendStatus(200);
     } catch (error) {
         return regenerateCorruptedIndexErrorHandler(req, res, error);
@@ -468,17 +490,12 @@ router.post('/purge', jsonParser, async (req, res) => {
         const collectionId = String(req.body.collectionId);
 
         for (const source of SOURCES) {
-            const index = await getIndex(req.user.directories, collectionId, source, false);
-
-            const exists = await index.isIndexCreated();
-
-            if (!exists) {
+            const sourcePath = path.join(req.user.directories.vectors, sanitize(source), sanitize(collectionId));
+            if (!fs.existsSync(sourcePath)) {
                 continue;
             }
-
-            const path = index.folderPath;
-            await index.deleteIndex();
-            console.log(`Deleted vector index at ${path}`);
+            await fs.promises.rm(sourcePath, { recursive: true });
+            console.log(`Deleted vector index at ${sourcePath}`);
         }
 
         return res.sendStatus(200);
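A closing note on how the reworked purge route interacts with the new layout: because model scopes nest each embedding model under the collection directory, removing that directory with fs.promises.rm(sourcePath, { recursive: true }) wipes every model-scoped subindex in one call, which the old per-index deleteIndex() could not do. The resulting tree for one hypothetical collection (assuming vectra keeps each index in an index.json inside its folder):

    // vectors/
    //   openai/
    //     chat123/                      <- removed wholesale by /purge
    //       index.json                  <- legacy unscoped index (scope '')
    //       text-embedding-3-small/
    //         index.json                <- model-scoped index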