diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js index 214b1d887..9e8777333 100644 --- a/public/scripts/extensions/vectors/index.js +++ b/public/scripts/extensions/vectors/index.js @@ -394,7 +394,8 @@ async function getSavedHashes(collectionId) { */ async function insertVectorItems(collectionId, items) { if (settings.source === 'openai' && !secret_state[SECRET_KEYS.OPENAI] || - settings.source === 'palm' && !secret_state[SECRET_KEYS.MAKERSUITE]) { + settings.source === 'palm' && !secret_state[SECRET_KEYS.MAKERSUITE] || + settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI]) { throw new Error('Vectors: API key missing', { cause: 'api_key_missing' }); } diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html index cb904b8a3..b1d74c83d 100644 --- a/public/scripts/extensions/vectors/settings.html +++ b/public/scripts/extensions/vectors/settings.html @@ -13,6 +13,7 @@ + diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js index e49d157fa..45d4d55a6 100644 --- a/src/endpoints/vectors.js +++ b/src/endpoints/vectors.js @@ -12,8 +12,9 @@ const { jsonParser } = require('../express-common'); */ async function getVector(source, text) { switch (source) { + case 'mistral': case 'openai': - return require('../openai-vectors').getOpenAIVector(text); + return require('../openai-vectors').getOpenAIVector(text, source); case 'transformers': return require('../embedding').getTransformersVector(text); case 'palm': diff --git a/src/openai-vectors.js b/src/openai-vectors.js index ecb245065..40c54ae2f 100644 --- a/src/openai-vectors.js +++ b/src/openai-vectors.js @@ -2,19 +2,21 @@ const fetch = require('node-fetch').default; const { SECRET_KEYS, readSecret } = require('./endpoints/secrets'); /** - * Gets the vector for the given text from OpenAI ada model + * Gets the vector for the given text from an OpenAI compatible endpoint. * @param {string} text - The text to get the vector for * @returns {Promise} - The vector for the text */ -async function getOpenAIVector(text) { - const key = readSecret(SECRET_KEYS.OPENAI); +async function getOpenAIVector(text, source) { + const isMistral = source === 'mistral'; + const key = readSecret(isMistral ? SECRET_KEYS.MISTRALAI : SECRET_KEYS.OPENAI); if (!key) { console.log('No OpenAI key found'); throw new Error('No OpenAI key found'); } - const response = await fetch('https://api.openai.com/v1/embeddings', { + const url = isMistral ? 'api.mistral.ai' : 'api.openai.com'; + const response = await fetch(`https://${url}/v1/embeddings`, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -22,7 +24,7 @@ async function getOpenAIVector(text) { }, body: JSON.stringify({ input: text, - model: 'text-embedding-ada-002', + model: isMistral ? 'mistral-embed' : 'text-embedding-ada-002', }), });