mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-03-09 00:17:47 +01:00
Add vector retrieval score threshold
This commit is contained in:
parent
039f3b875b
commit
967a7980f5
@ -52,6 +52,7 @@ const settings = {
|
||||
insert: 3,
|
||||
query: 2,
|
||||
message_chunk_size: 400,
|
||||
score_threshold: 0.25,
|
||||
|
||||
// For files
|
||||
enabled_files: false,
|
||||
@ -760,6 +761,7 @@ async function queryCollection(collectionId, searchText, topK) {
|
||||
searchText: searchText,
|
||||
topK: topK,
|
||||
source: settings.source,
|
||||
threshold: settings.score_threshold,
|
||||
}),
|
||||
});
|
||||
|
||||
@ -788,6 +790,7 @@ async function queryMultipleCollections(collectionIds, searchText, topK) {
|
||||
searchText: searchText,
|
||||
topK: topK,
|
||||
source: settings.source,
|
||||
threshold: settings.score_threshold,
|
||||
}),
|
||||
});
|
||||
|
||||
@ -1310,6 +1313,12 @@ jQuery(async () => {
|
||||
saveSettingsDebounced();
|
||||
});
|
||||
|
||||
$('#vectors_score_threshold').val(settings.score_threshold).on('input', () => {
|
||||
settings.score_threshold = Number($('#vectors_score_threshold').val());
|
||||
Object.assign(extension_settings.vectors, settings);
|
||||
saveSettingsDebounced();
|
||||
});
|
||||
|
||||
const validSecret = !!secret_state[SECRET_KEYS.NOMICAI];
|
||||
const placeholder = validSecret ? '✔️ Key saved' : '❌ Missing key';
|
||||
$('#api_key_nomicai').attr('placeholder', placeholder);
|
||||
|
@ -81,11 +81,19 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex-container flexFlowColumn" title="How many last messages will be matched for relevance.">
|
||||
<label for="vectors_query">
|
||||
<span>Query messages</span>
|
||||
</label>
|
||||
<input type="number" id="vectors_query" class="text_pole widthUnset" min="1" max="99" />
|
||||
<div class="flex-container marginTopBot5">
|
||||
<div class="flex-container flex1 flexFlowColumn" title="How many last messages will be matched for relevance.">
|
||||
<label for="vectors_query">
|
||||
<span>Query messages</span>
|
||||
</label>
|
||||
<input type="number" id="vectors_query" class="text_pole widthUnset" min="1" max="99" />
|
||||
</div>
|
||||
<div class="flex-container flex1 flexFlowColumn" title="Cut-off score for relevance. Helps to filter out irrelevant data.">
|
||||
<label for="vectors_query">
|
||||
<span>Score threshold</span>
|
||||
</label>
|
||||
<input type="number" id="vectors_score_threshold" class="text_pole widthUnset" min="0" max="1" step="0.05" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex-container">
|
||||
|
@ -168,14 +168,15 @@ async function deleteVectorItems(directories, collectionId, source, hashes) {
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string} searchText - The text to search for
|
||||
* @param {number} topK - The number of results to return
|
||||
* @param {number} threshold - The threshold for the search
|
||||
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
|
||||
*/
|
||||
async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK) {
|
||||
async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK, threshold) {
|
||||
const store = await getIndex(directories, collectionId, source);
|
||||
const vector = await getVector(source, sourceSettings, searchText, true, directories);
|
||||
|
||||
const result = await store.queryItems(vector, topK);
|
||||
const metadata = result.map(x => x.item.metadata);
|
||||
const metadata = result.filter(x => x.score >= threshold).map(x => x.item.metadata);
|
||||
const hashes = result.map(x => Number(x.item.metadata.hash));
|
||||
return { metadata, hashes };
|
||||
}
|
||||
@ -188,9 +189,11 @@ async function queryCollection(directories, collectionId, source, sourceSettings
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string} searchText - The text to search for
|
||||
* @param {number} topK - The number of results to return
|
||||
* @param {number} threshold - The threshold for the search
|
||||
*
|
||||
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - The top K results from each collection
|
||||
*/
|
||||
async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK) {
|
||||
async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK, threshold) {
|
||||
const vector = await getVector(source, sourceSettings, searchText, true, directories);
|
||||
const results = [];
|
||||
|
||||
@ -200,9 +203,10 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe
|
||||
results.push(...result.map(result => ({ collectionId, result })));
|
||||
}
|
||||
|
||||
// Sort results by descending similarity
|
||||
// Sort results by descending similarity, apply threshold, and take top K
|
||||
const sortedResults = results
|
||||
.sort((a, b) => b.result.score - a.result.score)
|
||||
.filter(x => x.result.score >= threshold)
|
||||
.slice(0, topK);
|
||||
|
||||
/**
|
||||
@ -274,10 +278,11 @@ router.post('/query', jsonParser, async (req, res) => {
|
||||
const collectionId = String(req.body.collectionId);
|
||||
const searchText = String(req.body.searchText);
|
||||
const topK = Number(req.body.topK) || 10;
|
||||
const threshold = Number(req.body.threshold) || 0.0;
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
const sourceSettings = getSourceSettings(source, req);
|
||||
|
||||
const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK);
|
||||
const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK, threshold);
|
||||
return res.json(results);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
@ -294,10 +299,11 @@ router.post('/query-multi', jsonParser, async (req, res) => {
|
||||
const collectionIds = req.body.collectionIds.map(x => String(x));
|
||||
const searchText = String(req.body.searchText);
|
||||
const topK = Number(req.body.topK) || 10;
|
||||
const threshold = Number(req.body.threshold) || 0.0;
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
const sourceSettings = getSourceSettings(source, req);
|
||||
|
||||
const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK);
|
||||
const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK, threshold);
|
||||
return res.json(results);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
Loading…
x
Reference in New Issue
Block a user