mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Update all endpoints to use user directories
This commit is contained in:
@@ -12,22 +12,23 @@ const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togethe
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getVector(source, sourceSettings, text) {
|
||||
async function getVector(source, sourceSettings, text, directories) {
|
||||
switch (source) {
|
||||
case 'nomicai':
|
||||
return require('../nomicai-vectors').getNomicAIVector(text, source);
|
||||
return require('../nomicai-vectors').getNomicAIVector(text, source, directories);
|
||||
case 'togetherai':
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
return require('../openai-vectors').getOpenAIVector(text, source, sourceSettings.model);
|
||||
return require('../openai-vectors').getOpenAIVector(text, source, directories, sourceSettings.model);
|
||||
case 'transformers':
|
||||
return require('../embedding').getTransformersVector(text);
|
||||
case 'extras':
|
||||
return require('../extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey);
|
||||
case 'palm':
|
||||
return require('../makersuite-vectors').getMakerSuiteVector(text);
|
||||
return require('../makersuite-vectors').getMakerSuiteVector(text, directories);
|
||||
}
|
||||
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
@@ -38,9 +39,10 @@ async function getVector(source, sourceSettings, text) {
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getBatchVector(source, sourceSettings, texts) {
|
||||
async function getBatchVector(source, sourceSettings, texts, directories) {
|
||||
const batchSize = 10;
|
||||
const batches = Array(Math.ceil(texts.length / batchSize)).fill(undefined).map((_, i) => texts.slice(i * batchSize, i * batchSize + batchSize));
|
||||
|
||||
@@ -48,7 +50,7 @@ async function getBatchVector(source, sourceSettings, texts) {
|
||||
for (let batch of batches) {
|
||||
switch (source) {
|
||||
case 'nomicai':
|
||||
results.push(...await require('../nomicai-vectors').getNomicAIBatchVector(batch, source));
|
||||
results.push(...await require('../nomicai-vectors').getNomicAIBatchVector(batch, source, directories));
|
||||
break;
|
||||
case 'togetherai':
|
||||
case 'mistral':
|
||||
@@ -62,7 +64,7 @@ async function getBatchVector(source, sourceSettings, texts) {
|
||||
results.push(...await require('../extras-vectors').getExtrasBatchVector(batch, sourceSettings.extrasUrl, sourceSettings.extrasKey));
|
||||
break;
|
||||
case 'palm':
|
||||
results.push(...await require('../makersuite-vectors').getMakerSuiteBatchVector(batch));
|
||||
results.push(...await require('../makersuite-vectors').getMakerSuiteBatchVector(batch, directories));
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
@@ -74,13 +76,15 @@ async function getBatchVector(source, sourceSettings, texts) {
|
||||
|
||||
/**
|
||||
* Gets the index for the vector collection
|
||||
* @param {import('../users').UserDirectoryList} directories - User directories
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {boolean} create - Whether to create the index if it doesn't exist
|
||||
* @returns {Promise<vectra.LocalIndex>} - The index for the collection
|
||||
*/
|
||||
async function getIndex(collectionId, source, create = true) {
|
||||
const store = new vectra.LocalIndex(path.join(process.cwd(), 'vectors', sanitize(source), sanitize(collectionId)));
|
||||
async function getIndex(directories, collectionId, source, create = true) {
|
||||
const pathToFile = path.join(directories.vectors, sanitize(source), sanitize(collectionId));
|
||||
const store = new vectra.LocalIndex(pathToFile);
|
||||
|
||||
if (create && !await store.isIndexCreated()) {
|
||||
await store.createIndex();
|
||||
@@ -91,17 +95,18 @@ async function getIndex(collectionId, source, create = true) {
|
||||
|
||||
/**
|
||||
* Inserts items into the vector collection
|
||||
* @param {import('../users').UserDirectoryList} directories - User directories
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {{ hash: number; text: string; index: number; }[]} items - The items to insert
|
||||
*/
|
||||
async function insertVectorItems(collectionId, source, sourceSettings, items) {
|
||||
const store = await getIndex(collectionId, source);
|
||||
async function insertVectorItems(directories, collectionId, source, sourceSettings, items) {
|
||||
const store = await getIndex(directories, collectionId, source);
|
||||
|
||||
await store.beginUpdate();
|
||||
|
||||
const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text));
|
||||
const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text), directories);
|
||||
|
||||
for (let i = 0; i < items.length; i++) {
|
||||
const item = items[i];
|
||||
@@ -114,12 +119,13 @@ async function insertVectorItems(collectionId, source, sourceSettings, items) {
|
||||
|
||||
/**
|
||||
* Gets the hashes of the items in the vector collection
|
||||
* @param {import('../users').UserDirectoryList} directories - User directories
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @returns {Promise<number[]>} - The hashes of the items in the collection
|
||||
*/
|
||||
async function getSavedHashes(collectionId, source) {
|
||||
const store = await getIndex(collectionId, source);
|
||||
async function getSavedHashes(directories, collectionId, source) {
|
||||
const store = await getIndex(directories, collectionId, source);
|
||||
|
||||
const items = await store.listItems();
|
||||
const hashes = items.map(x => Number(x.metadata.hash));
|
||||
@@ -129,12 +135,13 @@ async function getSavedHashes(collectionId, source) {
|
||||
|
||||
/**
|
||||
* Deletes items from the vector collection by hash
|
||||
* @param {import('../users').UserDirectoryList} directories - User directories
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {number[]} hashes - The hashes of the items to delete
|
||||
*/
|
||||
async function deleteVectorItems(collectionId, source, hashes) {
|
||||
const store = await getIndex(collectionId, source);
|
||||
async function deleteVectorItems(directories, collectionId, source, hashes) {
|
||||
const store = await getIndex(directories, collectionId, source);
|
||||
const items = await store.listItemsByMetadata({ hash: { '$in': hashes } });
|
||||
|
||||
await store.beginUpdate();
|
||||
@@ -155,9 +162,9 @@ async function deleteVectorItems(collectionId, source, hashes) {
|
||||
* @param {number} topK - The number of results to return
|
||||
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
|
||||
*/
|
||||
async function queryCollection(collectionId, source, sourceSettings, searchText, topK) {
|
||||
const store = await getIndex(collectionId, source);
|
||||
const vector = await getVector(source, sourceSettings, searchText);
|
||||
async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK) {
|
||||
const store = await getIndex(directories, collectionId, source);
|
||||
const vector = await getVector(source, sourceSettings, searchText, directories);
|
||||
|
||||
const result = await store.queryItems(vector, topK);
|
||||
const metadata = result.map(x => x.item.metadata);
|
||||
@@ -214,7 +221,7 @@ router.post('/query', jsonParser, async (req, res) => {
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
const sourceSettings = getSourceSettings(source, req);
|
||||
|
||||
const results = await queryCollection(collectionId, source, sourceSettings, searchText, topK);
|
||||
const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK);
|
||||
return res.json(results);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
@@ -233,7 +240,7 @@ router.post('/insert', jsonParser, async (req, res) => {
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
const sourceSettings = getSourceSettings(source, req);
|
||||
|
||||
await insertVectorItems(collectionId, source, sourceSettings, items);
|
||||
await insertVectorItems(req.user.directories, collectionId, source, sourceSettings, items);
|
||||
return res.sendStatus(200);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
@@ -250,7 +257,7 @@ router.post('/list', jsonParser, async (req, res) => {
|
||||
const collectionId = String(req.body.collectionId);
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
|
||||
const hashes = await getSavedHashes(collectionId, source);
|
||||
const hashes = await getSavedHashes(req.user.directories, collectionId, source);
|
||||
return res.json(hashes);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
@@ -268,7 +275,7 @@ router.post('/delete', jsonParser, async (req, res) => {
|
||||
const hashes = req.body.hashes.map(x => Number(x));
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
|
||||
await deleteVectorItems(collectionId, source, hashes);
|
||||
await deleteVectorItems(req.user.directories, collectionId, source, hashes);
|
||||
return res.sendStatus(200);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
@@ -285,7 +292,7 @@ router.post('/purge', jsonParser, async (req, res) => {
|
||||
const collectionId = String(req.body.collectionId);
|
||||
|
||||
for (const source of SOURCES) {
|
||||
const index = await getIndex(collectionId, source, false);
|
||||
const index = await getIndex(req.user.directories, collectionId, source, false);
|
||||
|
||||
const exists = await index.isIndexCreated();
|
||||
|
||||
|
Reference in New Issue
Block a user