diff --git a/src/endpoints/togetherai.js b/src/endpoints/togetherai.js new file mode 100644 index 000000000..309eff1f3 --- /dev/null +++ b/src/endpoints/togetherai.js @@ -0,0 +1,4 @@ +/** + * sends a request to the together AI api + * + */ \ No newline at end of file diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js index 387803ccb..eb3efc88b 100644 --- a/src/endpoints/vectors.js +++ b/src/endpoints/vectors.js @@ -12,8 +12,9 @@ const { jsonParser } = require('../express-common'); */ async function getVector(source, text) { switch (source) { + case 'togetherai': case 'openai': - return require('../openai-vectors').getOpenAIVector(text); + return require('../openai-vectors').getOpenAIVector(text, source); case 'transformers': return require('../embedding').getTransformersVector(text); case 'palm': diff --git a/src/openai-vectors.js b/src/openai-vectors.js index ecb245065..08bd15bf0 100644 --- a/src/openai-vectors.js +++ b/src/openai-vectors.js @@ -2,19 +2,40 @@ const fetch = require('node-fetch').default; const { SECRET_KEYS, readSecret } = require('./endpoints/secrets'); /** - * Gets the vector for the given text from OpenAI ada model + * Gets the vector for the given text from an OpenAI compatible endpoint. * @param {string} text - The text to get the vector for + * @param {string} source - The source of the vector * @returns {Promise} - The vector for the text */ -async function getOpenAIVector(text) { - const key = readSecret(SECRET_KEYS.OPENAI); +async function getOpenAIVector(text, source) { + + // dictionary of sources to endpoints with source as key and endpoint, model and secret key as value + const endpoints = { + 'togetherai': { + endpoint: 'https://api.togetherai.xyz/v1/embeddings', // is this correct? + model: 'togethercomputer/GPT-NeoXT-Chat-Base-20B', + secret: SECRET_KEYS.TOGETHERAI, + }, + 'openai': { + endpoint: 'https://api.openai.com/v1/embeddings', + model: 'text-embedding-ada-002', + secret: SECRET_KEYS.OPENAI, + }, + 'mistral': { + endpoint: 'https://api.mistral.ai/v1/embeddings', + model: 'mistral-embed', + secret: SECRET_KEYS.MISTRAL, + }, + }; + + const key = readSecret(endpoints[source].secret); if (!key) { - console.log('No OpenAI key found'); - throw new Error('No OpenAI key found'); + console.log('No %s key found.', source); + throw new Error('No ${source} key found.'); } - const response = await fetch('https://api.openai.com/v1/embeddings', { + const response = await fetch(endpoints[source].endpoint, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -22,22 +43,22 @@ async function getOpenAIVector(text) { }, body: JSON.stringify({ input: text, - model: 'text-embedding-ada-002', + model: endpoints[source].model, }), }); if (!response.ok) { const text = await response.text(); - console.log('OpenAI request failed', response.statusText, text); - throw new Error('OpenAI request failed'); + console.log('${source} request failed', response.statusText, text); + throw new Error('${source} request failed'); } const data = await response.json(); const vector = data?.data[0]?.embedding; if (!Array.isArray(vector)) { - console.log('OpenAI response was not an array'); - throw new Error('OpenAI response was not an array'); + console.log('${source} response was not an array'); + throw new Error('${source} response was not an array'); } return vector;