diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js
index b857998ae..ef13de22a 100644
--- a/public/scripts/extensions/vectors/index.js
+++ b/public/scripts/extensions/vectors/index.js
@@ -12,6 +12,7 @@ const settings = {
// For both
source: 'transformers',
include_wi: false,
+ model: 'togethercomputer/m2-bert-80M-32k-retrieval',
// For chats
enabled_chats: false,
@@ -440,6 +441,17 @@ function addExtrasHeaders(headers) {
});
}
+/**
+ * Add headers for the TogetherAI source.
+ * @param {object} headers Headers object
+ */
+function addTogetherAiHeaders(headers) {
+ Object.assign(headers, {
+ 'X-Togetherai-Model': extension_settings.vectors.model,
+ });
+}
+
+
/**
* Inserts vector items into a collection
* @param {string} collectionId - The collection to insert into
@@ -449,7 +461,8 @@ function addExtrasHeaders(headers) {
async function insertVectorItems(collectionId, items) {
if (settings.source === 'openai' && !secret_state[SECRET_KEYS.OPENAI] ||
settings.source === 'palm' && !secret_state[SECRET_KEYS.MAKERSUITE] ||
- settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI]) {
+ settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI] ||
+ settings.source === 'togetherai' && !secret_state[SECRET_KEYS.TOGETHERAI]) {
throw new Error('Vectors: API key missing', { cause: 'api_key_missing' });
}
@@ -460,6 +473,8 @@ async function insertVectorItems(collectionId, items) {
const headers = getRequestHeaders();
if (settings.source === 'extras') {
addExtrasHeaders(headers);
+ } else if (settings.source === 'togetherai') {
+ addTogetherAiHeaders(headers);
}
const response = await fetch('/api/vector/insert', {
@@ -509,6 +524,8 @@ async function queryCollection(collectionId, searchText, topK) {
const headers = getRequestHeaders();
if (settings.source === 'extras') {
addExtrasHeaders(headers);
+ } else if (settings.source === 'togetherai') {
+ addTogetherAiHeaders(headers);
}
const response = await fetch('/api/vector/query', {
@@ -526,8 +543,7 @@ async function queryCollection(collectionId, searchText, topK) {
throw new Error(`Failed to query collection ${collectionId}`);
}
- const results = await response.json();
- return results;
+ return await response.json();
}
/**
@@ -617,6 +633,13 @@ jQuery(async () => {
}
Object.assign(settings, extension_settings.vectors);
// Migrate from TensorFlow to Transformers
settings.source = settings.source !== 'local' ? settings.source : 'transformers';
$('#extensions_settings2').append(renderExtensionTemplate(MODULE_NAME, 'settings'));
+
+ if (settings.source === 'togetherai') {
+ $('#vectorsModel').show();
+ } else {
+ $('#vectorsModel').hide();
+ }
+
@@ -634,6 +657,17 @@ jQuery(async () => {
});
$('#vectors_source').val(settings.source).on('change', () => {
settings.source = String($('#vectors_source').val());
+ if (settings.source === 'togetherai') {
+ $('#vectorsModel').show();
+ } else {
+ $('#vectorsModel').hide();
+ }
+ Object.assign(extension_settings.vectors, settings);
+ saveSettingsDebounced();
+ });
+
+ $('#vectors_model').val(settings.model).on('change', () => {
+ settings.model = String($('#vectors_model').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
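
For orientation, a minimal sketch of how the chosen model leaves the client once the TogetherAI source is active; the header name and default value come from this diff, everything else is illustrative:

```js
// Illustrative only: the extra header attached before calling /api/vector/insert or /api/vector/query.
const headers = getRequestHeaders();   // base SillyTavern headers (content type, CSRF token, ...)
addTogetherAiHeaders(headers);         // adds 'X-Togetherai-Model': extension_settings.vectors.model
// e.g. headers['X-Togetherai-Model'] === 'togethercomputer/m2-bert-80M-32k-retrieval'
```
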
diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html
index 26e0fb731..fb26b14ff 100644
--- a/public/scripts/extensions/vectors/settings.html
+++ b/public/scripts/extensions/vectors/settings.html
@@ -15,6 +15,22 @@
+
+
+
+
+
+
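
Judging from the selectors driven by index.js above, the markup added here amounts to a `togetherai` entry in the `#vectors_source` dropdown plus a `#vectorsModel` container wrapping a `#vectors_model` text input that defaults to `togethercomputer/m2-bert-80M-32k-retrieval`; the exact classes and layout are whatever the rest of settings.html already uses.
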
diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js
index b2ff9e2d0..c3d0a2e5b 100644
--- a/src/endpoints/vectors.js
+++ b/src/endpoints/vectors.js
@@ -5,7 +5,7 @@ const sanitize = require('sanitize-filename');
const { jsonParser } = require('../express-common');
// Don't forget to add new sources to the SOURCES array
-const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm'];
+const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togetherai'];
/**
* Gets the vector for the given text from the given source.
@@ -16,9 +16,10 @@ const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm'];
*/
async function getVector(source, sourceSettings, text) {
switch (source) {
+ case 'togetherai':
case 'mistral':
case 'openai':
- return require('../openai-vectors').getOpenAIVector(text, source);
+ return require('../openai-vectors').getOpenAIVector(text, source, sourceSettings.model);
case 'transformers':
return require('../embedding').getTransformersVector(text);
case 'extras':
@@ -44,9 +45,10 @@ async function getBatchVector(source, sourceSettings, texts) {
let results = [];
for (let batch of batches) {
switch (source) {
+ case 'togetherai':
case 'mistral':
case 'openai':
- results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source));
+ results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, sourceSettings.model));
break;
case 'transformers':
results.push(...await require('../embedding').getTransformersBatchVector(batch));
@@ -165,19 +167,26 @@ async function queryCollection(collectionId, source, sourceSettings, searchText,
* @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
*/
function getSourceSettings(source, request) {
- // Extras API settings to connect to the Extras embeddings provider
- let extrasUrl = '';
- let extrasKey = '';
- if (source === 'extras') {
- extrasUrl = String(request.headers['x-extras-url']);
- extrasKey = String(request.headers['x-extras-key']);
- }
+ if (source === 'togetherai') {
+ const model = String(request.headers['x-togetherai-model']);
- const sourceSettings = {
- extrasUrl: extrasUrl,
- extrasKey: extrasKey,
- };
- return sourceSettings;
+ return {
+ model: model,
+ };
+ } else {
+ // Extras API settings to connect to the Extras embeddings provider
+ let extrasUrl = '';
+ let extrasKey = '';
+ if (source === 'extras') {
+ extrasUrl = String(request.headers['x-extras-url']);
+ extrasKey = String(request.headers['x-extras-key']);
+ }
+
+ return {
+ extrasUrl: extrasUrl,
+ extrasKey: extrasKey,
+ };
+ }
}
const router = express.Router();
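
To make the two branches concrete, a rough illustration of the settings objects getSourceSettings now produces; the request objects are mocked and the header values are examples:

```js
// Mocked Express-style requests, for illustration only.
getSourceSettings('togetherai', { headers: { 'x-togetherai-model': 'togethercomputer/m2-bert-80M-32k-retrieval' } });
// -> { model: 'togethercomputer/m2-bert-80M-32k-retrieval' }
// (if the header is missing, String(undefined) makes this the literal string 'undefined')

getSourceSettings('extras', { headers: { 'x-extras-url': 'http://localhost:5100', 'x-extras-key': 'example-key' } });
// -> { extrasUrl: 'http://localhost:5100', extrasKey: 'example-key' }

getSourceSettings('openai', { headers: {} });
// -> { extrasUrl: '', extrasKey: '' }   (extras headers are only read for the 'extras' source)
```
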
diff --git a/src/openai-vectors.js b/src/openai-vectors.js
index 2a0281e9c..222bc31aa 100644
--- a/src/openai-vectors.js
+++ b/src/openai-vectors.js
@@ -2,6 +2,11 @@ const fetch = require('node-fetch').default;
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
const SOURCES = {
+ 'togetherai': {
+ secretKey: SECRET_KEYS.TOGETHERAI,
+ url: 'api.together.xyz',
+ model: 'togethercomputer/m2-bert-80M-32k-retrieval',
+ },
'mistral': {
secretKey: SECRET_KEYS.MISTRALAI,
url: 'api.mistral.ai',
@@ -18,9 +23,10 @@ const SOURCES = {
* Gets the vector for the given text batch from an OpenAI compatible endpoint.
* @param {string[]} texts - The array of texts to get the vector for
* @param {string} source - The source of the vector
+ * @param {string} [model] - The model to use for the embedding (used by the 'togetherai' source)
* @returns {Promise} - The array of vectors for the texts
*/
-async function getOpenAIBatchVector(texts, source) {
+async function getOpenAIBatchVector(texts, source, model = '') {
const config = SOURCES[source];
if (!config) {
@@ -44,7 +50,7 @@ async function getOpenAIBatchVector(texts, source) {
},
body: JSON.stringify({
input: texts,
- model: config.model,
+ model: source === 'togetherai' && model !== '' ? model : config.model,
}),
});
@@ -72,10 +78,11 @@ async function getOpenAIBatchVector(texts, source) {
* Gets the vector for the given text from an OpenAI compatible endpoint.
* @param {string} text - The text to get the vector for
* @param {string} source - The source of the vector
+ * @param {string} [model] - The model to use for the embedding (used by the 'togetherai' source)
* @returns {Promise} - The vector for the text
*/
-async function getOpenAIVector(text, source) {
- const vectors = await getOpenAIBatchVector([text], source);
+async function getOpenAIVector(text, source, model = '') {
+ const vectors = await getOpenAIBatchVector([text], source, model);
return vectors[0];
}
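
Downstream, the TogetherAI call is just the existing OpenAI-compatible request pointed at a different host, with the per-request model taking precedence over the default. A sketch of the resulting call, assuming the unchanged parts of getOpenAIBatchVector keep building the URL from config.url and authenticating with the stored key as they do for the other sources:

```js
// Sketch only; URL and auth handling live in the unchanged part of this file.
const config = SOURCES['togetherai'];
const key = readSecret(SECRET_KEYS.TOGETHERAI);
const model = 'togethercomputer/m2-bert-80M-32k-retrieval'; // value forwarded from the X-Togetherai-Model header

const response = await fetch(`https://${config.url}/v1/embeddings`, {
    method: 'POST',
    headers: {
        'Content-Type': 'application/json',
        'Authorization': `Bearer ${key}`,
    },
    body: JSON.stringify({
        input: ['chunk of chat text', 'another chunk'], // the batch being embedded
        model: model !== '' ? model : config.model,     // header value wins, default otherwise
    }),
});
```
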