From 967a7980f5f708badcc6b2f386858764c146cac4 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Thu, 23 May 2024 17:28:43 +0300 Subject: [PATCH] Add vector retrieval score threshold --- public/scripts/extensions/vectors/index.js | 9 +++++++++ .../scripts/extensions/vectors/settings.html | 18 +++++++++++++----- src/endpoints/vectors.js | 18 ++++++++++++------ 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js index 1ed5bae93..a68faeb2e 100644 --- a/public/scripts/extensions/vectors/index.js +++ b/public/scripts/extensions/vectors/index.js @@ -52,6 +52,7 @@ const settings = { insert: 3, query: 2, message_chunk_size: 400, + score_threshold: 0.25, // For files enabled_files: false, @@ -760,6 +761,7 @@ async function queryCollection(collectionId, searchText, topK) { searchText: searchText, topK: topK, source: settings.source, + threshold: settings.score_threshold, }), }); @@ -788,6 +790,7 @@ async function queryMultipleCollections(collectionIds, searchText, topK) { searchText: searchText, topK: topK, source: settings.source, + threshold: settings.score_threshold, }), }); @@ -1310,6 +1313,12 @@ jQuery(async () => { saveSettingsDebounced(); }); + $('#vectors_score_threshold').val(settings.score_threshold).on('input', () => { + settings.score_threshold = Number($('#vectors_score_threshold').val()); + Object.assign(extension_settings.vectors, settings); + saveSettingsDebounced(); + }); + const validSecret = !!secret_state[SECRET_KEYS.NOMICAI]; const placeholder = validSecret ? '✔️ Key saved' : '❌ Missing key'; $('#api_key_nomicai').attr('placeholder', placeholder); diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html index bcaf1c06e..efb78f6b9 100644 --- a/public/scripts/extensions/vectors/settings.html +++ b/public/scripts/extensions/vectors/settings.html @@ -81,11 +81,19 @@ -
- - +
+
+ + +
+
+ + +
diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js index 990796fb1..519b4c284 100644 --- a/src/endpoints/vectors.js +++ b/src/endpoints/vectors.js @@ -168,14 +168,15 @@ async function deleteVectorItems(directories, collectionId, source, hashes) { * @param {Object} sourceSettings - Settings for the source, if it needs any * @param {string} searchText - The text to search for * @param {number} topK - The number of results to return + * @param {number} threshold - The threshold for the search * @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text */ -async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK) { +async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK, threshold) { const store = await getIndex(directories, collectionId, source); const vector = await getVector(source, sourceSettings, searchText, true, directories); const result = await store.queryItems(vector, topK); - const metadata = result.map(x => x.item.metadata); + const metadata = result.filter(x => x.score >= threshold).map(x => x.item.metadata); const hashes = result.map(x => Number(x.item.metadata.hash)); return { metadata, hashes }; } @@ -188,9 +189,11 @@ async function queryCollection(directories, collectionId, source, sourceSettings * @param {Object} sourceSettings - Settings for the source, if it needs any * @param {string} searchText - The text to search for * @param {number} topK - The number of results to return + * @param {number} threshold - The threshold for the search + * * @returns {Promise>} - The top K results from each collection */ -async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK) { +async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK, threshold) { const vector = await getVector(source, sourceSettings, searchText, true, directories); const results = []; @@ -200,9 +203,10 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe results.push(...result.map(result => ({ collectionId, result }))); } - // Sort results by descending similarity + // Sort results by descending similarity, apply threshold, and take top K const sortedResults = results .sort((a, b) => b.result.score - a.result.score) + .filter(x => x.result.score >= threshold) .slice(0, topK); /** @@ -274,10 +278,11 @@ router.post('/query', jsonParser, async (req, res) => { const collectionId = String(req.body.collectionId); const searchText = String(req.body.searchText); const topK = Number(req.body.topK) || 10; + const threshold = Number(req.body.threshold) || 0.0; const source = String(req.body.source) || 'transformers'; const sourceSettings = getSourceSettings(source, req); - const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK); + const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK, threshold); return res.json(results); } catch (error) { console.error(error); @@ -294,10 +299,11 @@ router.post('/query-multi', jsonParser, async (req, res) => { const collectionIds = req.body.collectionIds.map(x => String(x)); const searchText = String(req.body.searchText); const topK = Number(req.body.topK) || 10; + const threshold = Number(req.body.threshold) || 0.0; const source = String(req.body.source) || 'transformers'; const sourceSettings = getSourceSettings(source, req); - const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK); + const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK, threshold); return res.json(results); } catch (error) { console.error(error);