Implement collection scopes for vector models (#2846)

* Implement collection scopes for vector models

* Update makersuite-vectors.js to use Gemini API text-embedding-004 model

* Add model scope for Google vectors

* Fix purge log

* Refactor header setting

* Fix typo

* Only display UI warning if scopes disabled

* Remove i18n attribute

---------

Co-authored-by: ceruleandeep <83318388+ceruleandeep@users.noreply.github.com>
Author: Cohee
Date: 2024-09-16 09:29:39 +03:00
Committed by: GitHub
Parent: 7eb7204dc7
Commit: 9ef3385255
4 changed files with 166 additions and 182 deletions


@@ -4,6 +4,7 @@ const fs = require('fs');
const express = require('express');
const sanitize = require('sanitize-filename');
const { jsonParser } = require('../express-common');
const { getConfigValue, color } = require('../util');
// Don't forget to add new sources to the SOURCES array
const SOURCES = [
@@ -109,19 +110,95 @@ async function getBatchVector(source, sourceSettings, texts, isQuery, directorie
return results;
}
/**
* Extracts settings for the vectorization sources from the HTTP request headers.
* @param {string} source - Which source to extract settings for.
* @param {object} request - The HTTP request object.
* @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
*/
function getSourceSettings(source, request) {
switch (source) {
case 'togetherai':
return {
model: String(request.headers['x-togetherai-model']),
};
case 'openai':
return {
model: String(request.headers['x-openai-model']),
};
case 'cohere':
return {
model: String(request.headers['x-cohere-model']),
};
case 'llamacpp':
return {
apiUrl: String(request.headers['x-llamacpp-url']),
};
case 'vllm':
return {
apiUrl: String(request.headers['x-vllm-url']),
model: String(request.headers['x-vllm-model']),
};
case 'ollama':
return {
apiUrl: String(request.headers['x-ollama-url']),
model: String(request.headers['x-ollama-model']),
keep: Boolean(request.headers['x-ollama-keep']),
};
case 'extras':
return {
extrasUrl: String(request.headers['x-extras-url']),
extrasKey: String(request.headers['x-extras-key']),
};
case 'local':
return {
model: getConfigValue('extras.embeddingModel', ''),
};
case 'palm':
return {
// TODO: Add support for multiple models
model: 'text-embedding-004',
};
default:
return {};
}
}
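// Illustrative sketch (not part of the diff): a request carrying the header
// 'x-openai-model: text-embedding-3-small' makes getSourceSettings('openai', request)
// return { model: 'text-embedding-3-small' }, and that model string is what
// getModelScope() below folds into the collection path. The model name is an example.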
/**
* Gets the model scope for the source.
* @param {object} sourceSettings - The settings for the source
* @returns {string} The model scope for the source
*/
function getModelScope(sourceSettings) {
const scopesEnabled = getConfigValue('vectors.enableModelScopes', false);
const warningShown = global.process.env.VECTORS_MODEL_SCOPE_WARNING_SHOWN === 'true';
if (!scopesEnabled && !warningShown) {
console.log();
console.warn(color.red('[DEPRECATION NOTICE]'), 'Model scopes for Vector Storage are disabled, but will soon be required.');
console.log(`To enable model scopes, set ${color.cyan('vectors.enableModelScopes')} in config.yaml to ${color.green(true)}.`);
console.log('This message won\'t be shown again in the current session.');
console.log();
global.process.env.VECTORS_MODEL_SCOPE_WARNING_SHOWN = 'true';
}
return scopesEnabled ? (sourceSettings?.model || '') : '';
}
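// Minimal sketch of the opt-in (not part of the diff): enabling scopes is a config.yaml edit,
//   vectors:
//     enableModelScopes: true
// While the flag is off, getModelScope() returns '' and collections keep their old,
// model-agnostic folder layout, so nothing changes on disk until the user opts in.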
/**
* Gets the index for the vector collection
* @param {import('../users').UserDirectoryList} directories - User directories
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {boolean} create - Whether to create the index if it doesn't exist
* @param {object} sourceSettings - The settings for the source
* @returns {Promise<vectra.LocalIndex>} - The index for the collection
*/
async function getIndex(directories, collectionId, source, create = true) {
const pathToFile = path.join(directories.vectors, sanitize(source), sanitize(collectionId));
async function getIndex(directories, collectionId, source, sourceSettings) {
const model = getModelScope(sourceSettings);
const pathToFile = path.join(directories.vectors, sanitize(source), sanitize(collectionId), sanitize(model));
const store = new vectra.LocalIndex(pathToFile);
if (create && !await store.isIndexCreated()) {
if (!await store.isIndexCreated()) {
await store.createIndex();
}
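// Illustrative layout (not part of the diff): with scopes enabled the index is created under a
// per-model folder, e.g. <dataRoot>/vectors/openai/<collectionId>/text-embedding-3-small/, so
// switching embedding models no longer mixes vectors of different dimensions in a single index.
// Folder names here are examples only.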
@@ -137,7 +214,7 @@ async function getIndex(directories, collectionId, source, create = true) {
* @param {{ hash: number; text: string; index: number; }[]} items - The items to insert
*/
async function insertVectorItems(directories, collectionId, source, sourceSettings, items) {
const store = await getIndex(directories, collectionId, source);
const store = await getIndex(directories, collectionId, source, sourceSettings);
await store.beginUpdate();
@@ -157,10 +234,11 @@ async function insertVectorItems(directories, collectionId, source, sourceSettin
* @param {import('../users').UserDirectoryList} directories - User directories
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @returns {Promise<number[]>} - The hashes of the items in the collection
*/
async function getSavedHashes(directories, collectionId, source) {
const store = await getIndex(directories, collectionId, source);
async function getSavedHashes(directories, collectionId, source, sourceSettings) {
const store = await getIndex(directories, collectionId, source, sourceSettings);
const items = await store.listItems();
const hashes = items.map(x => Number(x.metadata.hash));
@@ -173,10 +251,11 @@ async function getSavedHashes(directories, collectionId, source) {
* @param {import('../users').UserDirectoryList} directories - User directories
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {number[]} hashes - The hashes of the items to delete
*/
async function deleteVectorItems(directories, collectionId, source, hashes) {
const store = await getIndex(directories, collectionId, source);
async function deleteVectorItems(directories, collectionId, source, sourceSettings, hashes) {
const store = await getIndex(directories, collectionId, source, sourceSettings);
const items = await store.listItemsByMetadata({ hash: { '$in': hashes } });
await store.beginUpdate();
@@ -200,7 +279,7 @@ async function deleteVectorItems(directories, collectionId, source, hashes) {
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
*/
async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK, threshold) {
const store = await getIndex(directories, collectionId, source);
const store = await getIndex(directories, collectionId, source, sourceSettings);
const vector = await getVector(source, sourceSettings, searchText, true, directories);
const result = await store.queryItems(vector, topK);
@@ -226,7 +305,7 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe
const results = [];
for (const collectionId of collectionIds) {
const store = await getIndex(directories, collectionId, source);
const store = await getIndex(directories, collectionId, source, sourceSettings);
const result = await store.queryItems(vector, topK);
results.push(...result.map(result => ({ collectionId, result })));
}
@@ -254,71 +333,6 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe
return groupedResults;
}
/**
* Extracts settings for the vectorization sources from the HTTP request headers.
* @param {string} source - Which source to extract settings for.
* @param {object} request - The HTTP request object.
* @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
*/
function getSourceSettings(source, request) {
if (source === 'togetherai') {
const model = String(request.headers['x-togetherai-model']);
return {
model: model,
};
} else if (source === 'openai') {
const model = String(request.headers['x-openai-model']);
return {
model: model,
};
} else if (source === 'cohere') {
const model = String(request.headers['x-cohere-model']);
return {
model: model,
};
} else if (source === 'llamacpp') {
const apiUrl = String(request.headers['x-llamacpp-url']);
return {
apiUrl: apiUrl,
};
} else if (source === 'vllm') {
const apiUrl = String(request.headers['x-vllm-url']);
const model = String(request.headers['x-vllm-model']);
return {
apiUrl: apiUrl,
model: model,
};
} else if (source === 'ollama') {
const apiUrl = String(request.headers['x-ollama-url']);
const model = String(request.headers['x-ollama-model']);
const keep = Boolean(request.headers['x-ollama-keep']);
return {
apiUrl: apiUrl,
model: model,
keep: keep,
};
} else {
// Extras API settings to connect to the Extras embeddings provider
let extrasUrl = '';
let extrasKey = '';
if (source === 'extras') {
extrasUrl = String(request.headers['x-extras-url']);
extrasKey = String(request.headers['x-extras-key']);
}
return {
extrasUrl: extrasUrl,
extrasKey: extrasKey,
};
}
}
/**
* Performs a request to regenerate the index if it is corrupted.
* @param {import('express').Request} req Express request object
@@ -330,9 +344,10 @@ async function regenerateCorruptedIndexErrorHandler(req, res, error) {
if (error instanceof SyntaxError && !req.query.regenerated) {
const collectionId = String(req.body.collectionId);
const source = String(req.body.source) || 'transformers';
const sourceSettings = getSourceSettings(source, req);
if (collectionId && source) {
const index = await getIndex(req.user.directories, collectionId, source, false);
const index = await getIndex(req.user.directories, collectionId, source, sourceSettings);
const exists = await index.isIndexCreated();
if (exists) {
@@ -350,6 +365,11 @@ async function regenerateCorruptedIndexErrorHandler(req, res, error) {
const router = express.Router();
router.get('/scopes-enabled', (_req, res) => {
const scopesEnabled = getConfigValue('vectors.enableModelScopes', false);
return res.json({ enabled: scopesEnabled });
});
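// Minimal client-side sketch for the new endpoint (not part of the diff; the '/api/vector'
// mount prefix is an assumption):
//   const { enabled } = await fetch('/api/vector/scopes-enabled').then(r => r.json());
//   if (!enabled) { /* surface the "model scopes disabled" warning in the UI */ }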
router.post('/query', jsonParser, async (req, res) => {
try {
if (!req.body.collectionId || !req.body.searchText) {
@@ -416,8 +436,9 @@ router.post('/list', jsonParser, async (req, res) => {
const collectionId = String(req.body.collectionId);
const source = String(req.body.source) || 'transformers';
const sourceSettings = getSourceSettings(source, req);
const hashes = await getSavedHashes(req.user.directories, collectionId, source);
const hashes = await getSavedHashes(req.user.directories, collectionId, source, sourceSettings);
return res.json(hashes);
} catch (error) {
return regenerateCorruptedIndexErrorHandler(req, res, error);
@@ -433,8 +454,9 @@ router.post('/delete', jsonParser, async (req, res) => {
const collectionId = String(req.body.collectionId);
const hashes = req.body.hashes.map(x => Number(x));
const source = String(req.body.source) || 'transformers';
const sourceSettings = getSourceSettings(source, req);
await deleteVectorItems(req.user.directories, collectionId, source, hashes);
await deleteVectorItems(req.user.directories, collectionId, source, sourceSettings, hashes);
return res.sendStatus(200);
} catch (error) {
return regenerateCorruptedIndexErrorHandler(req, res, error);
@@ -468,17 +490,12 @@ router.post('/purge', jsonParser, async (req, res) => {
const collectionId = String(req.body.collectionId);
for (const source of SOURCES) {
const index = await getIndex(req.user.directories, collectionId, source, false);
const exists = await index.isIndexCreated();
if (!exists) {
const sourcePath = path.join(req.user.directories.vectors, sanitize(source), sanitize(collectionId));
if (!fs.existsSync(sourcePath)) {
continue;
}
const path = index.folderPath;
await index.deleteIndex();
console.log(`Deleted vector index at ${path}`);
await fs.promises.rm(sourcePath, { recursive: true });
console.log(`Deleted vector index at ${sourcePath}`);
}
return res.sendStatus(200);
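// End-to-end sketch (not part of the diff; the '/api/vector' mount prefix, header, and model
// names are assumptions for illustration):
//   await fetch('/api/vector/list', {
//       method: 'POST',
//       headers: { 'Content-Type': 'application/json', 'x-openai-model': 'text-embedding-3-small' },
//       body: JSON.stringify({ collectionId: 'chat-123', source: 'openai' }),
//   });
// getSourceSettings() reads the x-openai-model header, getModelScope() turns it into the folder
// scope, and getSavedHashes() opens the per-model index created by getIndex().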