From 967a7980f5f708badcc6b2f386858764c146cac4 Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Thu, 23 May 2024 17:28:43 +0300
Subject: [PATCH] Add vector retrieval score threshold
---
public/scripts/extensions/vectors/index.js | 9 +++++++++
.../scripts/extensions/vectors/settings.html | 18 +++++++++++++-----
src/endpoints/vectors.js | 18 ++++++++++++------
3 files changed, 34 insertions(+), 11 deletions(-)
diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js
index 1ed5bae93..a68faeb2e 100644
--- a/public/scripts/extensions/vectors/index.js
+++ b/public/scripts/extensions/vectors/index.js
@@ -52,6 +52,7 @@ const settings = {
insert: 3,
query: 2,
message_chunk_size: 400,
+ score_threshold: 0.25,
// For files
enabled_files: false,
@@ -760,6 +761,7 @@ async function queryCollection(collectionId, searchText, topK) {
searchText: searchText,
topK: topK,
source: settings.source,
+ threshold: settings.score_threshold,
}),
});
@@ -788,6 +790,7 @@ async function queryMultipleCollections(collectionIds, searchText, topK) {
searchText: searchText,
topK: topK,
source: settings.source,
+ threshold: settings.score_threshold,
}),
});
@@ -1310,6 +1313,12 @@ jQuery(async () => {
saveSettingsDebounced();
});
+ $('#vectors_score_threshold').val(settings.score_threshold).on('input', () => {
+ settings.score_threshold = Number($('#vectors_score_threshold').val());
+ Object.assign(extension_settings.vectors, settings);
+ saveSettingsDebounced();
+ });
+
const validSecret = !!secret_state[SECRET_KEYS.NOMICAI];
const placeholder = validSecret ? '✔️ Key saved' : '❌ Missing key';
$('#api_key_nomicai').attr('placeholder', placeholder);
diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html
index bcaf1c06e..efb78f6b9 100644
--- a/public/scripts/extensions/vectors/settings.html
+++ b/public/scripts/extensions/vectors/settings.html
@@ -81,11 +81,19 @@
-
-
-
+
diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js
index 990796fb1..519b4c284 100644
--- a/src/endpoints/vectors.js
+++ b/src/endpoints/vectors.js
@@ -168,14 +168,15 @@ async function deleteVectorItems(directories, collectionId, source, hashes) {
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {string} searchText - The text to search for
* @param {number} topK - The number of results to return
+ * @param {number} threshold - The threshold for the search
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
*/
-async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK) {
+async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK, threshold) {
const store = await getIndex(directories, collectionId, source);
const vector = await getVector(source, sourceSettings, searchText, true, directories);
const result = await store.queryItems(vector, topK);
- const metadata = result.map(x => x.item.metadata);
+ const metadata = result.filter(x => x.score >= threshold).map(x => x.item.metadata);
const hashes = result.map(x => Number(x.item.metadata.hash));
return { metadata, hashes };
}
@@ -188,9 +189,11 @@ async function queryCollection(directories, collectionId, source, sourceSettings
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {string} searchText - The text to search for
* @param {number} topK - The number of results to return
+ * @param {number} threshold - The threshold for the search
+ *
* @returns {Promise>} - The top K results from each collection
*/
-async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK) {
+async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK, threshold) {
const vector = await getVector(source, sourceSettings, searchText, true, directories);
const results = [];
@@ -200,9 +203,10 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe
results.push(...result.map(result => ({ collectionId, result })));
}
- // Sort results by descending similarity
+ // Sort results by descending similarity, apply threshold, and take top K
const sortedResults = results
.sort((a, b) => b.result.score - a.result.score)
+ .filter(x => x.result.score >= threshold)
.slice(0, topK);
/**
@@ -274,10 +278,11 @@ router.post('/query', jsonParser, async (req, res) => {
const collectionId = String(req.body.collectionId);
const searchText = String(req.body.searchText);
const topK = Number(req.body.topK) || 10;
+ const threshold = Number(req.body.threshold) || 0.0;
const source = String(req.body.source) || 'transformers';
const sourceSettings = getSourceSettings(source, req);
- const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK);
+ const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK, threshold);
return res.json(results);
} catch (error) {
console.error(error);
@@ -294,10 +299,11 @@ router.post('/query-multi', jsonParser, async (req, res) => {
const collectionIds = req.body.collectionIds.map(x => String(x));
const searchText = String(req.body.searchText);
const topK = Number(req.body.topK) || 10;
+ const threshold = Number(req.body.threshold) || 0.0;
const source = String(req.body.source) || 'transformers';
const sourceSettings = getSourceSettings(source, req);
- const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK);
+ const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK, threshold);
return res.json(results);
} catch (error) {
console.error(error);