Add vector retrieval score threshold

This commit is contained in:
Cohee 2024-05-23 17:28:43 +03:00
parent 039f3b875b
commit 967a7980f5
3 changed files with 34 additions and 11 deletions

View File

@ -52,6 +52,7 @@ const settings = {
insert: 3,
query: 2,
message_chunk_size: 400,
score_threshold: 0.25,
// For files
enabled_files: false,
@ -760,6 +761,7 @@ async function queryCollection(collectionId, searchText, topK) {
searchText: searchText,
topK: topK,
source: settings.source,
threshold: settings.score_threshold,
}),
});
@ -788,6 +790,7 @@ async function queryMultipleCollections(collectionIds, searchText, topK) {
searchText: searchText,
topK: topK,
source: settings.source,
threshold: settings.score_threshold,
}),
});
@ -1310,6 +1313,12 @@ jQuery(async () => {
saveSettingsDebounced();
});
$('#vectors_score_threshold').val(settings.score_threshold).on('input', () => {
settings.score_threshold = Number($('#vectors_score_threshold').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
const validSecret = !!secret_state[SECRET_KEYS.NOMICAI];
const placeholder = validSecret ? '✔️ Key saved' : '❌ Missing key';
$('#api_key_nomicai').attr('placeholder', placeholder);

View File

@ -81,11 +81,19 @@
</div>
</div>
<div class="flex-container flexFlowColumn" title="How many last messages will be matched for relevance.">
<label for="vectors_query">
<span>Query messages</span>
</label>
<input type="number" id="vectors_query" class="text_pole widthUnset" min="1" max="99" />
<div class="flex-container marginTopBot5">
<div class="flex-container flex1 flexFlowColumn" title="How many last messages will be matched for relevance.">
<label for="vectors_query">
<span>Query messages</span>
</label>
<input type="number" id="vectors_query" class="text_pole widthUnset" min="1" max="99" />
</div>
<div class="flex-container flex1 flexFlowColumn" title="Cut-off score for relevance. Helps to filter out irrelevant data.">
<label for="vectors_score_threshold">
<span>Score threshold</span>
</label>
<input type="number" id="vectors_score_threshold" class="text_pole widthUnset" min="0" max="1" step="0.05" />
</div>
</div>
<div class="flex-container">

View File

@ -168,14 +168,15 @@ async function deleteVectorItems(directories, collectionId, source, hashes) {
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {string} searchText - The text to search for
* @param {number} topK - The number of results to return
* @param {number} threshold - The threshold for the search
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
*/
async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK, threshold = 0) {
    // `threshold` defaults to 0 so that callers using the previous six-argument form
    // are not compared against `undefined`, which would reject every result.
    const store = await getIndex(directories, collectionId, source);
    const vector = await getVector(source, sourceSettings, searchText, true, directories);

    // Apply the score cut-off once, so `metadata` and `hashes` are derived from the
    // same set of matches and stay index-aligned with each other.
    const matches = (await store.queryItems(vector, topK)).filter(x => x.score >= threshold);
    const metadata = matches.map(x => x.item.metadata);
    const hashes = matches.map(x => Number(x.item.metadata.hash));
    return { metadata, hashes };
}
@ -188,9 +189,11 @@ async function queryCollection(directories, collectionId, source, sourceSettings
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {string} searchText - The text to search for
* @param {number} topK - The number of results to return
* @param {number} threshold - The threshold for the search
*
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - The top K results from each collection
*/
async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK) {
async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK, threshold) {
const vector = await getVector(source, sourceSettings, searchText, true, directories);
const results = [];
@ -200,9 +203,10 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe
results.push(...result.map(result => ({ collectionId, result })));
}
// Sort results by descending similarity
// Sort results by descending similarity, apply threshold, and take top K
const sortedResults = results
.sort((a, b) => b.result.score - a.result.score)
.filter(x => x.result.score >= threshold)
.slice(0, topK);
/**
@ -274,10 +278,11 @@ router.post('/query', jsonParser, async (req, res) => {
const collectionId = String(req.body.collectionId);
const searchText = String(req.body.searchText);
const topK = Number(req.body.topK) || 10;
const threshold = Number(req.body.threshold) || 0.0;
const source = String(req.body.source) || 'transformers';
const sourceSettings = getSourceSettings(source, req);
const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK);
const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK, threshold);
return res.json(results);
} catch (error) {
console.error(error);
@ -294,10 +299,11 @@ router.post('/query-multi', jsonParser, async (req, res) => {
const collectionIds = req.body.collectionIds.map(x => String(x));
const searchText = String(req.body.searchText);
const topK = Number(req.body.topK) || 10;
const threshold = Number(req.body.threshold) || 0.0;
const source = String(req.body.source) || 'transformers';
const sourceSettings = getSourceSettings(source, req);
const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK);
const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK, threshold);
return res.json(results);
} catch (error) {
console.error(error);