diff --git a/default/config.yaml b/default/config.yaml
index 5f3e0ce9a..694a58be6 100644
--- a/default/config.yaml
+++ b/default/config.yaml
@@ -108,6 +108,9 @@ enableExtensionsAutoUpdate: true
# Additional model tokenizers can be downloaded on demand.
# Disabling will fall back to another locally available tokenizer.
enableDownloadableTokenizers: true
+# Vector storage settings
+vectors:
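+ # When enabled, vectors are stored in per-model subfolders so that switching
+ # embedding models doesn't require purging existing vectors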
+ enableModelScopes: false
# Extension settings
extras:
# Disables automatic model download from HuggingFace
diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js
index fb97a3d1a..1b815612e 100644
--- a/public/scripts/extensions/vectors/index.js
+++ b/public/scripts/extensions/vectors/index.js
@@ -718,7 +718,7 @@ async function getQueryText(chat, initiator) {
async function getSavedHashes(collectionId) {
const response = await fetch('/api/vector/list', {
method: 'POST',
- headers: getRequestHeaders(),
+ headers: getVectorHeaders(),
body: JSON.stringify({
collectionId: collectionId,
source: settings.source,
@@ -737,25 +737,43 @@ function getVectorHeaders() {
const headers = getRequestHeaders();
switch (settings.source) {
case 'extras':
- addExtrasHeaders(headers);
+ Object.assign(headers, {
+ 'X-Extras-Url': extension_settings.apiUrl,
+ 'X-Extras-Key': extension_settings.apiKey,
+ });
break;
case 'togetherai':
- addTogetherAiHeaders(headers);
+ Object.assign(headers, {
+ 'X-Togetherai-Model': extension_settings.vectors.togetherai_model,
+ });
break;
case 'openai':
- addOpenAiHeaders(headers);
+ Object.assign(headers, {
+ 'X-OpenAI-Model': extension_settings.vectors.openai_model,
+ });
break;
case 'cohere':
- addCohereHeaders(headers);
+ Object.assign(headers, {
+ 'X-Cohere-Model': extension_settings.vectors.cohere_model,
+ });
break;
case 'ollama':
- addOllamaHeaders(headers);
+ Object.assign(headers, {
+ 'X-Ollama-Model': extension_settings.vectors.ollama_model,
+ 'X-Ollama-URL': textgenerationwebui_settings.server_urls[textgen_types.OLLAMA],
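+ // Boolean values are serialized to the strings 'true'/'false' when sent as headers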
+ 'X-Ollama-Keep': !!extension_settings.vectors.ollama_keep,
+ });
break;
case 'llamacpp':
- addLlamaCppHeaders(headers);
+ Object.assign(headers, {
+ 'X-LlamaCpp-URL': textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP],
+ });
break;
case 'vllm':
- addVllmHeaders(headers);
+ Object.assign(headers, {
+ 'X-Vllm-URL': textgenerationwebui_settings.server_urls[textgen_types.VLLM],
+ 'X-Vllm-Model': extension_settings.vectors.vllm_model,
+ });
break;
default:
break;
@@ -763,81 +781,6 @@ function getVectorHeaders() {
return headers;
}
-/**
- * Add headers for the Extras API source.
- * @param {object} headers Headers object
- */
-function addExtrasHeaders(headers) {
- console.log(`Vector source is extras, populating API URL: ${extension_settings.apiUrl}`);
- Object.assign(headers, {
- 'X-Extras-Url': extension_settings.apiUrl,
- 'X-Extras-Key': extension_settings.apiKey,
- });
-}
-
-/**
- * Add headers for the TogetherAI API source.
- * @param {object} headers Headers object
- */
-function addTogetherAiHeaders(headers) {
- Object.assign(headers, {
- 'X-Togetherai-Model': extension_settings.vectors.togetherai_model,
- });
-}
-
-/**
- * Add headers for the OpenAI API source.
- * @param {object} headers Header object
- */
-function addOpenAiHeaders(headers) {
- Object.assign(headers, {
- 'X-OpenAI-Model': extension_settings.vectors.openai_model,
- });
-}
-
-/**
- * Add headers for the Cohere API source.
- * @param {object} headers Header object
- */
-function addCohereHeaders(headers) {
- Object.assign(headers, {
- 'X-Cohere-Model': extension_settings.vectors.cohere_model,
- });
-}
-
-/**
- * Add headers for the Ollama API source.
- * @param {object} headers Header object
- */
-function addOllamaHeaders(headers) {
- Object.assign(headers, {
- 'X-Ollama-Model': extension_settings.vectors.ollama_model,
- 'X-Ollama-URL': textgenerationwebui_settings.server_urls[textgen_types.OLLAMA],
- 'X-Ollama-Keep': !!extension_settings.vectors.ollama_keep,
- });
-}
-
-/**
- * Add headers for the LlamaCpp API source.
- * @param {object} headers Header object
- */
-function addLlamaCppHeaders(headers) {
- Object.assign(headers, {
- 'X-LlamaCpp-URL': textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP],
- });
-}
-
-/**
- * Add headers for the VLLM API source.
- * @param {object} headers Header object
- */
-function addVllmHeaders(headers) {
- Object.assign(headers, {
- 'X-Vllm-URL': textgenerationwebui_settings.server_urls[textgen_types.VLLM],
- 'X-Vllm-Model': extension_settings.vectors.vllm_model,
- });
-}
-
/**
* Inserts vector items into a collection
* @param {string} collectionId - The collection to insert into
@@ -901,7 +844,7 @@ function throwIfSourceInvalid() {
async function deleteVectorItems(collectionId, hashes) {
const response = await fetch('/api/vector/delete', {
method: 'POST',
- headers: getRequestHeaders(),
+ headers: getVectorHeaders(),
body: JSON.stringify({
collectionId: collectionId,
hashes: hashes,
@@ -987,7 +930,7 @@ async function purgeFileVectorIndex(fileUrl) {
const response = await fetch('/api/vector/purge', {
method: 'POST',
- headers: getRequestHeaders(),
+ headers: getVectorHeaders(),
body: JSON.stringify({
collectionId: collectionId,
}),
@@ -1016,7 +959,7 @@ async function purgeVectorIndex(collectionId) {
const response = await fetch('/api/vector/purge', {
method: 'POST',
- headers: getRequestHeaders(),
+ headers: getVectorHeaders(),
body: JSON.stringify({
collectionId: collectionId,
}),
@@ -1041,7 +984,7 @@ async function purgeAllVectorIndexes() {
try {
const response = await fetch('/api/vector/purge-all', {
method: 'POST',
- headers: getRequestHeaders(),
+ headers: getVectorHeaders(),
});
if (!response.ok) {
@@ -1056,6 +999,25 @@ async function purgeAllVectorIndexes() {
}
}
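+/**
+ * Checks if the server has model scopes for vector storage enabled.
+ * @returns {Promise<boolean>} Whether model scopes are enabled
+ */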
+async function isModelScopesEnabled() {
+ try {
+ const response = await fetch('/api/vector/scopes-enabled', {
+ method: 'GET',
+ headers: getVectorHeaders(),
+ });
+
+ if (!response.ok) {
+ return false;
+ }
+
+ const data = await response.json();
+ return data?.enabled ?? false;
+ } catch (error) {
+ console.error('Vectors: Failed to check model scopes', error);
+ return false;
+ }
+}
+
function toggleSettings() {
$('#vectors_files_settings').toggle(!!settings.enabled_files);
$('#vectors_chats_settings').toggle(!!settings.enabled_chats);
@@ -1320,6 +1282,7 @@ jQuery(async () => {
}
Object.assign(settings, extension_settings.vectors);
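+ // With model scopes enabled, the server keeps a separate index per embedding model,
+ // so switching models doesn't invalidate existing vectors (see vectors.enableModelScopes)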
+ const scopesEnabled = await isModelScopesEnabled();
// Migrate from TensorFlow to Transformers
settings.source = settings.source !== 'local' ? settings.source : 'transformers';
@@ -1371,31 +1334,31 @@ jQuery(async () => {
saveSettingsDebounced();
});
$('#vectors_togetherai_model').val(settings.togetherai_model).on('change', () => {
- $('#vectors_modelWarning').show();
+ !scopesEnabled && $('#vectors_modelWarning').show();
settings.togetherai_model = String($('#vectors_togetherai_model').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_openai_model').val(settings.openai_model).on('change', () => {
- $('#vectors_modelWarning').show();
+ !scopesEnabled && $('#vectors_modelWarning').show();
settings.openai_model = String($('#vectors_openai_model').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_cohere_model').val(settings.cohere_model).on('change', () => {
- $('#vectors_modelWarning').show();
+ !scopesEnabled && $('#vectors_modelWarning').show();
settings.cohere_model = String($('#vectors_cohere_model').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_ollama_model').val(settings.ollama_model).on('input', () => {
- $('#vectors_modelWarning').show();
+ !scopesEnabled && $('#vectors_modelWarning').show();
settings.ollama_model = String($('#vectors_ollama_model').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_vllm_model').val(settings.vllm_model).on('input', () => {
- $('#vectors_modelWarning').show();
+ !scopesEnabled && $('#vectors_modelWarning').show();
settings.vllm_model = String($('#vectors_vllm_model').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html
index 0dc626e53..f1e73016e 100644
--- a/public/scripts/extensions/vectors/settings.html
+++ b/public/scripts/extensions/vectors/settings.html
@@ -98,8 +98,9 @@
-
- It is recommended to purge vectors when changing the model mid-chat. Otherwise, it will lead to sub-par results.
+
+ Set vectors.enableModelScopes to true in config.yaml to switch between vectorization models without needing to purge existing vectors.
+ This option will soon be enabled by default.
diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js
index 38f74f7d8..790b5693d 100644
--- a/src/endpoints/vectors.js
+++ b/src/endpoints/vectors.js
@@ -4,6 +4,7 @@ const fs = require('fs');
const express = require('express');
const sanitize = require('sanitize-filename');
const { jsonParser } = require('../express-common');
+const { getConfigValue, color } = require('../util');
// Don't forget to add new sources to the SOURCES array
const SOURCES = [
@@ -109,19 +110,95 @@ async function getBatchVector(source, sourceSettings, texts, isQuery, directorie
return results;
}
+/**
+ * Extracts settings for the vectorization sources from the HTTP request headers.
+ * @param {string} source - Which source to extract settings for.
+ * @param {object} request - The HTTP request object.
+ * @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
+ */
+function getSourceSettings(source, request) {
+ switch (source) {
+ case 'togetherai':
+ return {
+ model: String(request.headers['x-togetherai-model']),
+ };
+ case 'openai':
+ return {
+ model: String(request.headers['x-openai-model']),
+ };
+ case 'cohere':
+ return {
+ model: String(request.headers['x-cohere-model']),
+ };
+ case 'llamacpp':
+ return {
+ apiUrl: String(request.headers['x-llamacpp-url']),
+ };
+ case 'vllm':
+ return {
+ apiUrl: String(request.headers['x-vllm-url']),
+ model: String(request.headers['x-vllm-model']),
+ };
+ case 'ollama':
+ return {
+ apiUrl: String(request.headers['x-ollama-url']),
+ model: String(request.headers['x-ollama-model']),
+ // Header values arrive as the strings 'true'/'false'; Boolean() would treat 'false' as true
+ keep: request.headers['x-ollama-keep'] === 'true',
+ };
+ case 'extras':
+ return {
+ extrasUrl: String(request.headers['x-extras-url']),
+ extrasKey: String(request.headers['x-extras-key']),
+ };
+ case 'local':
+ return {
+ model: getConfigValue('extras.embeddingModel', ''),
+ };
+ case 'palm':
+ return {
+ // TODO: Add support for multiple models
+ model: 'text-embedding-004',
+ };
+ default:
+ return {};
+ }
+}
+
+/**
+ * Gets the model scope for the source.
+ * @param {object} sourceSettings - The settings for the source
+ * @returns {string} The model scope for the source
+ */
+function getModelScope(sourceSettings) {
+ const scopesEnabled = getConfigValue('vectors.enableModelScopes', false);
+ const warningShown = global.process.env.VECTORS_MODEL_SCOPE_WARNING_SHOWN === 'true';
+
+ if (!scopesEnabled && !warningShown) {
+ console.log();
+ console.warn(color.red('[DEPRECATION NOTICE]'), 'Model scopes for Vector Storage are disabled, but will soon be required.');
+ console.log(`To enable model scopes, set ${color.cyan('vectors.enableModelScopes')} in config.yaml to ${color.green(true)}.`);
+ console.log('This message won\'t be shown again in the current session.');
+ console.log();
+ global.process.env.VECTORS_MODEL_SCOPE_WARNING_SHOWN = 'true';
+ }
+
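+ // With scopes disabled, return an empty scope so all models share one index per collection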
+ return scopesEnabled ? (sourceSettings?.model || '') : '';
+}
+
/**
* Gets the index for the vector collection
* @param {import('../users').UserDirectoryList} directories - User directories
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
- * @param {boolean} create - Whether to create the index if it doesn't exist
+ * @param {object} sourceSettings - The settings for the source
* @returns {Promise} - The index for the collection
*/
-async function getIndex(directories, collectionId, source, create = true) {
- const pathToFile = path.join(directories.vectors, sanitize(source), sanitize(collectionId));
+async function getIndex(directories, collectionId, source, sourceSettings) {
+ const model = getModelScope(sourceSettings);
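+ // The model scope adds a path segment; an empty scope (scopes disabled) resolves
+ // to the legacy vectors/<source>/<collectionId> layout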
+ const pathToFile = path.join(directories.vectors, sanitize(source), sanitize(collectionId), sanitize(model));
const store = new vectra.LocalIndex(pathToFile);
- if (create && !await store.isIndexCreated()) {
+ if (!await store.isIndexCreated()) {
await store.createIndex();
}
@@ -137,7 +214,7 @@ async function getIndex(directories, collectionId, source, create = true) {
* @param {{ hash: number; text: string; index: number; }[]} items - The items to insert
*/
async function insertVectorItems(directories, collectionId, source, sourceSettings, items) {
- const store = await getIndex(directories, collectionId, source);
+ const store = await getIndex(directories, collectionId, source, sourceSettings);
await store.beginUpdate();
@@ -157,10 +234,11 @@ async function insertVectorItems(directories, collectionId, source, sourceSettin
* @param {import('../users').UserDirectoryList} directories - User directories
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
+ * @param {Object} sourceSettings - Settings for the source, if it needs any
* @returns {Promise} - The hashes of the items in the collection
*/
-async function getSavedHashes(directories, collectionId, source) {
- const store = await getIndex(directories, collectionId, source);
+async function getSavedHashes(directories, collectionId, source, sourceSettings) {
+ const store = await getIndex(directories, collectionId, source, sourceSettings);
const items = await store.listItems();
const hashes = items.map(x => Number(x.metadata.hash));
@@ -173,10 +251,11 @@ async function getSavedHashes(directories, collectionId, source) {
* @param {import('../users').UserDirectoryList} directories - User directories
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
+ * @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {number[]} hashes - The hashes of the items to delete
*/
-async function deleteVectorItems(directories, collectionId, source, hashes) {
- const store = await getIndex(directories, collectionId, source);
+async function deleteVectorItems(directories, collectionId, source, sourceSettings, hashes) {
+ const store = await getIndex(directories, collectionId, source, sourceSettings);
const items = await store.listItemsByMetadata({ hash: { '$in': hashes } });
await store.beginUpdate();
@@ -200,7 +279,7 @@ async function deleteVectorItems(directories, collectionId, source, hashes) {
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
*/
async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK, threshold) {
- const store = await getIndex(directories, collectionId, source);
+ const store = await getIndex(directories, collectionId, source, sourceSettings);
const vector = await getVector(source, sourceSettings, searchText, true, directories);
const result = await store.queryItems(vector, topK);
@@ -226,7 +305,7 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe
const results = [];
for (const collectionId of collectionIds) {
- const store = await getIndex(directories, collectionId, source);
+ const store = await getIndex(directories, collectionId, source, sourceSettings);
const result = await store.queryItems(vector, topK);
results.push(...result.map(result => ({ collectionId, result })));
}
@@ -254,71 +333,6 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe
return groupedResults;
}
-/**
- * Extracts settings for the vectorization sources from the HTTP request headers.
- * @param {string} source - Which source to extract settings for.
- * @param {object} request - The HTTP request object.
- * @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
- */
-function getSourceSettings(source, request) {
- if (source === 'togetherai') {
- const model = String(request.headers['x-togetherai-model']);
-
- return {
- model: model,
- };
- } else if (source === 'openai') {
- const model = String(request.headers['x-openai-model']);
-
- return {
- model: model,
- };
- } else if (source === 'cohere') {
- const model = String(request.headers['x-cohere-model']);
-
- return {
- model: model,
- };
- } else if (source === 'llamacpp') {
- const apiUrl = String(request.headers['x-llamacpp-url']);
-
- return {
- apiUrl: apiUrl,
- };
- } else if (source === 'vllm') {
- const apiUrl = String(request.headers['x-vllm-url']);
- const model = String(request.headers['x-vllm-model']);
-
- return {
- apiUrl: apiUrl,
- model: model,
- };
- } else if (source === 'ollama') {
- const apiUrl = String(request.headers['x-ollama-url']);
- const model = String(request.headers['x-ollama-model']);
- const keep = Boolean(request.headers['x-ollama-keep']);
-
- return {
- apiUrl: apiUrl,
- model: model,
- keep: keep,
- };
- } else {
- // Extras API settings to connect to the Extras embeddings provider
- let extrasUrl = '';
- let extrasKey = '';
- if (source === 'extras') {
- extrasUrl = String(request.headers['x-extras-url']);
- extrasKey = String(request.headers['x-extras-key']);
- }
-
- return {
- extrasUrl: extrasUrl,
- extrasKey: extrasKey,
- };
- }
-}
-
/**
* Performs a request to regenerate the index if it is corrupted.
* @param {import('express').Request} req Express request object
@@ -330,9 +344,10 @@ async function regenerateCorruptedIndexErrorHandler(req, res, error) {
if (error instanceof SyntaxError && !req.query.regenerated) {
const collectionId = String(req.body.collectionId);
const source = String(req.body.source) || 'transformers';
+ const sourceSettings = getSourceSettings(source, req);
if (collectionId && source) {
- const index = await getIndex(req.user.directories, collectionId, source, false);
+ const index = await getIndex(req.user.directories, collectionId, source, sourceSettings);
const exists = await index.isIndexCreated();
if (exists) {
@@ -350,6 +365,11 @@ async function regenerateCorruptedIndexErrorHandler(req, res, error) {
const router = express.Router();
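+// Lets the client check whether model scopes are enabled, e.g.:
+// GET /api/vector/scopes-enabled -> { "enabled": false }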
+router.get('/scopes-enabled', (_req, res) => {
+ const scopesEnabled = getConfigValue('vectors.enableModelScopes', false);
+ return res.json({ enabled: scopesEnabled });
+});
+
router.post('/query', jsonParser, async (req, res) => {
try {
if (!req.body.collectionId || !req.body.searchText) {
@@ -416,8 +436,9 @@ router.post('/list', jsonParser, async (req, res) => {
const collectionId = String(req.body.collectionId);
const source = String(req.body.source) || 'transformers';
+ const sourceSettings = getSourceSettings(source, req);
- const hashes = await getSavedHashes(req.user.directories, collectionId, source);
+ const hashes = await getSavedHashes(req.user.directories, collectionId, source, sourceSettings);
return res.json(hashes);
} catch (error) {
return regenerateCorruptedIndexErrorHandler(req, res, error);
@@ -433,8 +454,9 @@ router.post('/delete', jsonParser, async (req, res) => {
const collectionId = String(req.body.collectionId);
const hashes = req.body.hashes.map(x => Number(x));
const source = String(req.body.source) || 'transformers';
+ const sourceSettings = getSourceSettings(source, req);
- await deleteVectorItems(req.user.directories, collectionId, source, hashes);
+ await deleteVectorItems(req.user.directories, collectionId, source, sourceSettings, hashes);
return res.sendStatus(200);
} catch (error) {
return regenerateCorruptedIndexErrorHandler(req, res, error);
@@ -468,17 +490,12 @@ router.post('/purge', jsonParser, async (req, res) => {
const collectionId = String(req.body.collectionId);
for (const source of SOURCES) {
- const index = await getIndex(req.user.directories, collectionId, source, false);
-
- const exists = await index.isIndexCreated();
-
- if (!exists) {
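+ // Delete the collection folder directly so any per-model subfolders are removed too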
+ const sourcePath = path.join(req.user.directories.vectors, sanitize(source), sanitize(collectionId));
+ if (!fs.existsSync(sourcePath)) {
continue;
}
-
- const path = index.folderPath;
- await index.deleteIndex();
- console.log(`Deleted vector index at ${path}`);
+ await fs.promises.rm(sourcePath, { recursive: true });
+ console.log(`Deleted vector index at ${sourcePath}`);
}
return res.sendStatus(200);