diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js index 75c2923fd..887a730f0 100644 --- a/src/endpoints/vectors.js +++ b/src/endpoints/vectors.js @@ -30,16 +30,19 @@ async function getVector(source, sourceSettings, text) { /** * Gets the vector for the given text batch from the given source. * @param {string} source - The source of the vector + * @param {Object} sourceSettings - Settings for the source, if it needs any * @param {string[]} texts - The array of texts to get the vector for * @returns {Promise} - The array of vectors for the texts */ -async function getBatchVector(source, texts) { +async function getBatchVector(source, sourceSettings, texts) { switch (source) { case 'mistral': case 'openai': return require('../openai-vectors').getOpenAIBatchVector(texts, source); case 'transformers': return require('../embedding').getTransformersBatchVector(texts); + case 'extras': + return require('../extras-vectors').getExtrasBatchVector(texts, sourceSettings.extrasUrl, sourceSettings.extrasKey); case 'palm': return require('../makersuite-vectors').getMakerSuiteBatchVector(texts); } @@ -76,7 +79,7 @@ async function insertVectorItems(collectionId, source, sourceSettings, items) { await store.beginUpdate(); - const vectors = await getBatchVector(source, items.map(x => x.text)); + const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text)); for (let i = 0; i < items.length; i++) { const item = items[i]; diff --git a/src/extras-vectors.js b/src/extras-vectors.js index 7b91c9211..fb435369f 100644 --- a/src/extras-vectors.js +++ b/src/extras-vectors.js @@ -2,10 +2,26 @@ const fetch = require('node-fetch').default; /** * Gets the vector for the given text from SillyTavern-extras - * @param {string|Array} text - The text or texts to get the vector for + * @param {string[]} texts - The array of texts to get the vector for * @param {string} apiUrl - The Extras API URL * @param {string} - The Extras API key, or empty string if API key not enabled - * @returns {Promise} - The vector for a single text, or the array of vectors for multiple texts + * @returns {Promise} - The array of vectors for the texts + */ +async function getExtrasBatchVector(texts, apiUrl, apiKey) { + return getExtrasVector(texts, apiUrl, apiKey); // The implementation supports batches transparently. +} + +module.exports = { + getExtrasVector, + getExtrasBatchVector, +}; + +/** + * Gets the vector for the given text from SillyTavern-extras + * @param {string|string[]} text - The text or texts to get the vector for + * @param {string} apiUrl - The Extras API URL + * @param {string} - The Extras API key, or empty string if API key not enabled + * @returns {Promise|Promise} - The vector for a single text, or the array of vectors for multiple texts */ async function getExtrasVector(text, apiUrl, apiKey) { let url; @@ -33,7 +49,7 @@ async function getExtrasVector(text, apiUrl, apiKey) { method: 'POST', headers: headers, body: JSON.stringify({ - text: text, // The backend accepts {string|Array} for one or multiple text items, respectively. + text: text, // The backend accepts {string|string[]} for one or multiple text items, respectively. }), }); @@ -44,11 +60,7 @@ async function getExtrasVector(text, apiUrl, apiKey) { } const data = await response.json(); - const vector = data.embedding; // `embedding`: Array (one text item), or Array of Array (multiple text items). + const vector = data.embedding; // `embedding`: number[] (one text item), or number[][] (multiple text items). return vector; } - -module.exports = { - getExtrasVector, -};