mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Implement Data Bank vectors querying
This commit is contained in:
@@ -5,7 +5,7 @@ const sanitize = require('sanitize-filename');
|
||||
const { jsonParser } = require('../express-common');
|
||||
|
||||
// Don't forget to add new sources to the SOURCES array
|
||||
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togetherai', 'nomicai'];
|
||||
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togetherai', 'nomicai', 'cohere'];
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from the given source.
|
||||
@@ -55,7 +55,7 @@ async function getBatchVector(source, sourceSettings, texts, directories) {
|
||||
case 'togetherai':
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, sourceSettings.model));
|
||||
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, directories, sourceSettings.model));
|
||||
break;
|
||||
case 'transformers':
|
||||
results.push(...await require('../embedding').getTransformersBatchVector(batch));
|
||||
@@ -155,6 +155,7 @@ async function deleteVectorItems(directories, collectionId, source, hashes) {
|
||||
|
||||
/**
|
||||
* Gets the hashes of the items in the vector collection that match the search text
|
||||
* @param {import('../users').UserDirectoryList} directories - User directories
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
@@ -172,6 +173,48 @@ async function queryCollection(directories, collectionId, source, sourceSettings
|
||||
return { metadata, hashes };
|
||||
}
|
||||
|
||||
/**
|
||||
* Queries multiple collections for the given search queries. Returns the overall top K results.
|
||||
* @param {import('../users').UserDirectoryList} directories - User directories
|
||||
* @param {string[]} collectionIds - The collection IDs to query
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string} searchText - The text to search for
|
||||
* @param {number} topK - The number of results to return
|
||||
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - The top K results from each collection
|
||||
*/
|
||||
async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK) {
|
||||
const vector = await getVector(source, sourceSettings, searchText, directories);
|
||||
const results = [];
|
||||
|
||||
for (const collectionId of collectionIds) {
|
||||
const store = await getIndex(directories, collectionId, source);
|
||||
const result = await store.queryItems(vector, topK);
|
||||
results.push(...result.map(result => ({ collectionId, result })));
|
||||
}
|
||||
|
||||
// Sort results by descending similarity
|
||||
const sortedResults = results
|
||||
.sort((a, b) => b.result.score - a.result.score)
|
||||
.slice(0, topK);
|
||||
|
||||
/**
|
||||
* Group the results by collection ID
|
||||
* @type {Record<string, { hashes: number[], metadata: object[] }>}
|
||||
*/
|
||||
const groupedResults = {};
|
||||
for (const result of sortedResults) {
|
||||
if (!groupedResults[result.collectionId]) {
|
||||
groupedResults[result.collectionId] = { hashes: [], metadata: [] };
|
||||
}
|
||||
|
||||
groupedResults[result.collectionId].hashes.push(Number(result.result.item.metadata.hash));
|
||||
groupedResults[result.collectionId].metadata.push(result.result.item.metadata);
|
||||
}
|
||||
|
||||
return groupedResults;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts settings for the vectorization sources from the HTTP request headers.
|
||||
* @param {string} source - Which source to extract settings for.
|
||||
@@ -229,6 +272,26 @@ router.post('/query', jsonParser, async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/query-multi', jsonParser, async (req, res) => {
|
||||
try {
|
||||
if (!Array.isArray(req.body.collectionIds) || !req.body.searchText) {
|
||||
return res.sendStatus(400);
|
||||
}
|
||||
|
||||
const collectionIds = req.body.collectionIds.map(x => String(x));
|
||||
const searchText = String(req.body.searchText);
|
||||
const topK = Number(req.body.topK) || 10;
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
const sourceSettings = getSourceSettings(source, req);
|
||||
|
||||
const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK);
|
||||
return res.json(results);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return res.sendStatus(500);
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/insert', jsonParser, async (req, res) => {
|
||||
try {
|
||||
if (!Array.isArray(req.body.items) || !req.body.collectionId) {
|
||||
|
Reference in New Issue
Block a user