mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-03-09 16:40:10 +01:00
Add vector retrieval score threshold
This commit is contained in:
parent
039f3b875b
commit
967a7980f5
@ -52,6 +52,7 @@ const settings = {
|
|||||||
insert: 3,
|
insert: 3,
|
||||||
query: 2,
|
query: 2,
|
||||||
message_chunk_size: 400,
|
message_chunk_size: 400,
|
||||||
|
score_threshold: 0.25,
|
||||||
|
|
||||||
// For files
|
// For files
|
||||||
enabled_files: false,
|
enabled_files: false,
|
||||||
@ -760,6 +761,7 @@ async function queryCollection(collectionId, searchText, topK) {
|
|||||||
searchText: searchText,
|
searchText: searchText,
|
||||||
topK: topK,
|
topK: topK,
|
||||||
source: settings.source,
|
source: settings.source,
|
||||||
|
threshold: settings.score_threshold,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -788,6 +790,7 @@ async function queryMultipleCollections(collectionIds, searchText, topK) {
|
|||||||
searchText: searchText,
|
searchText: searchText,
|
||||||
topK: topK,
|
topK: topK,
|
||||||
source: settings.source,
|
source: settings.source,
|
||||||
|
threshold: settings.score_threshold,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -1310,6 +1313,12 @@ jQuery(async () => {
|
|||||||
saveSettingsDebounced();
|
saveSettingsDebounced();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
$('#vectors_score_threshold').val(settings.score_threshold).on('input', () => {
|
||||||
|
settings.score_threshold = Number($('#vectors_score_threshold').val());
|
||||||
|
Object.assign(extension_settings.vectors, settings);
|
||||||
|
saveSettingsDebounced();
|
||||||
|
});
|
||||||
|
|
||||||
const validSecret = !!secret_state[SECRET_KEYS.NOMICAI];
|
const validSecret = !!secret_state[SECRET_KEYS.NOMICAI];
|
||||||
const placeholder = validSecret ? '✔️ Key saved' : '❌ Missing key';
|
const placeholder = validSecret ? '✔️ Key saved' : '❌ Missing key';
|
||||||
$('#api_key_nomicai').attr('placeholder', placeholder);
|
$('#api_key_nomicai').attr('placeholder', placeholder);
|
||||||
|
@ -81,11 +81,19 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="flex-container flexFlowColumn" title="How many last messages will be matched for relevance.">
|
<div class="flex-container marginTopBot5">
|
||||||
<label for="vectors_query">
|
<div class="flex-container flex1 flexFlowColumn" title="How many last messages will be matched for relevance.">
|
||||||
<span>Query messages</span>
|
<label for="vectors_query">
|
||||||
</label>
|
<span>Query messages</span>
|
||||||
<input type="number" id="vectors_query" class="text_pole widthUnset" min="1" max="99" />
|
</label>
|
||||||
|
<input type="number" id="vectors_query" class="text_pole widthUnset" min="1" max="99" />
|
||||||
|
</div>
|
||||||
|
<div class="flex-container flex1 flexFlowColumn" title="Cut-off score for relevance. Helps to filter out irrelevant data.">
|
||||||
|
<label for="vectors_query">
|
||||||
|
<span>Score threshold</span>
|
||||||
|
</label>
|
||||||
|
<input type="number" id="vectors_score_threshold" class="text_pole widthUnset" min="0" max="1" step="0.05" />
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="flex-container">
|
<div class="flex-container">
|
||||||
|
@ -168,14 +168,15 @@ async function deleteVectorItems(directories, collectionId, source, hashes) {
|
|||||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||||
* @param {string} searchText - The text to search for
|
* @param {string} searchText - The text to search for
|
||||||
* @param {number} topK - The number of results to return
|
* @param {number} topK - The number of results to return
|
||||||
|
* @param {number} threshold - The threshold for the search
|
||||||
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
|
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
|
||||||
*/
|
*/
|
||||||
async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK) {
|
async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK, threshold) {
|
||||||
const store = await getIndex(directories, collectionId, source);
|
const store = await getIndex(directories, collectionId, source);
|
||||||
const vector = await getVector(source, sourceSettings, searchText, true, directories);
|
const vector = await getVector(source, sourceSettings, searchText, true, directories);
|
||||||
|
|
||||||
const result = await store.queryItems(vector, topK);
|
const result = await store.queryItems(vector, topK);
|
||||||
const metadata = result.map(x => x.item.metadata);
|
const metadata = result.filter(x => x.score >= threshold).map(x => x.item.metadata);
|
||||||
const hashes = result.map(x => Number(x.item.metadata.hash));
|
const hashes = result.map(x => Number(x.item.metadata.hash));
|
||||||
return { metadata, hashes };
|
return { metadata, hashes };
|
||||||
}
|
}
|
||||||
@ -188,9 +189,11 @@ async function queryCollection(directories, collectionId, source, sourceSettings
|
|||||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||||
* @param {string} searchText - The text to search for
|
* @param {string} searchText - The text to search for
|
||||||
* @param {number} topK - The number of results to return
|
* @param {number} topK - The number of results to return
|
||||||
|
* @param {number} threshold - The threshold for the search
|
||||||
|
*
|
||||||
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - The top K results from each collection
|
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - The top K results from each collection
|
||||||
*/
|
*/
|
||||||
async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK) {
|
async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK, threshold) {
|
||||||
const vector = await getVector(source, sourceSettings, searchText, true, directories);
|
const vector = await getVector(source, sourceSettings, searchText, true, directories);
|
||||||
const results = [];
|
const results = [];
|
||||||
|
|
||||||
@ -200,9 +203,10 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe
|
|||||||
results.push(...result.map(result => ({ collectionId, result })));
|
results.push(...result.map(result => ({ collectionId, result })));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort results by descending similarity
|
// Sort results by descending similarity, apply threshold, and take top K
|
||||||
const sortedResults = results
|
const sortedResults = results
|
||||||
.sort((a, b) => b.result.score - a.result.score)
|
.sort((a, b) => b.result.score - a.result.score)
|
||||||
|
.filter(x => x.result.score >= threshold)
|
||||||
.slice(0, topK);
|
.slice(0, topK);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -274,10 +278,11 @@ router.post('/query', jsonParser, async (req, res) => {
|
|||||||
const collectionId = String(req.body.collectionId);
|
const collectionId = String(req.body.collectionId);
|
||||||
const searchText = String(req.body.searchText);
|
const searchText = String(req.body.searchText);
|
||||||
const topK = Number(req.body.topK) || 10;
|
const topK = Number(req.body.topK) || 10;
|
||||||
|
const threshold = Number(req.body.threshold) || 0.0;
|
||||||
const source = String(req.body.source) || 'transformers';
|
const source = String(req.body.source) || 'transformers';
|
||||||
const sourceSettings = getSourceSettings(source, req);
|
const sourceSettings = getSourceSettings(source, req);
|
||||||
|
|
||||||
const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK);
|
const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK, threshold);
|
||||||
return res.json(results);
|
return res.json(results);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error);
|
console.error(error);
|
||||||
@ -294,10 +299,11 @@ router.post('/query-multi', jsonParser, async (req, res) => {
|
|||||||
const collectionIds = req.body.collectionIds.map(x => String(x));
|
const collectionIds = req.body.collectionIds.map(x => String(x));
|
||||||
const searchText = String(req.body.searchText);
|
const searchText = String(req.body.searchText);
|
||||||
const topK = Number(req.body.topK) || 10;
|
const topK = Number(req.body.topK) || 10;
|
||||||
|
const threshold = Number(req.body.threshold) || 0.0;
|
||||||
const source = String(req.body.source) || 'transformers';
|
const source = String(req.body.source) || 'transformers';
|
||||||
const sourceSettings = getSourceSettings(source, req);
|
const sourceSettings = getSourceSettings(source, req);
|
||||||
|
|
||||||
const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK);
|
const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK, threshold);
|
||||||
return res.json(results);
|
return res.json(results);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error);
|
console.error(error);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user