Implement TogetherAI as vectorization provider

This commit is contained in:
Kristan Schlikow
2024-03-01 23:52:49 +01:00
parent 95c49029f7
commit adfb9c5097
4 changed files with 88 additions and 22 deletions

View File

@@ -5,7 +5,7 @@ const sanitize = require('sanitize-filename');
const { jsonParser } = require('../express-common');
// Don't forget to add new sources to the SOURCES array
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm'];
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togetherai'];
/**
* Gets the vector for the given text from the given source.
@@ -16,9 +16,10 @@ const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm'];
*/
async function getVector(source, sourceSettings, text) {
switch (source) {
case 'togetherai':
case 'mistral':
case 'openai':
return require('../openai-vectors').getOpenAIVector(text, source);
return require('../openai-vectors').getOpenAIVector(text, source, sourceSettings.model);
case 'transformers':
return require('../embedding').getTransformersVector(text);
case 'extras':
@@ -44,9 +45,10 @@ async function getBatchVector(source, sourceSettings, texts) {
let results = [];
for (let batch of batches) {
switch (source) {
case 'togetherai':
case 'mistral':
case 'openai':
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source));
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, sourceSettings.model));
break;
case 'transformers':
results.push(...await require('../embedding').getTransformersBatchVector(batch));
@@ -165,19 +167,26 @@ async function queryCollection(collectionId, source, sourceSettings, searchText,
* @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
*/
function getSourceSettings(source, request) {
// Extras API settings to connect to the Extras embeddings provider
let extrasUrl = '';
let extrasKey = '';
if (source === 'extras') {
extrasUrl = String(request.headers['x-extras-url']);
extrasKey = String(request.headers['x-extras-key']);
}
if (source === 'togetherai') {
let model = String(request.headers['x-togetherai-model']);
const sourceSettings = {
extrasUrl: extrasUrl,
extrasKey: extrasKey,
};
return sourceSettings;
return {
model: model,
};
} else {
// Extras API settings to connect to the Extras embeddings provider
let extrasUrl = '';
let extrasKey = '';
if (source === 'extras') {
extrasUrl = String(request.headers['x-extras-url']);
extrasKey = String(request.headers['x-extras-key']);
}
return {
extrasUrl: extrasUrl,
extrasKey: extrasKey,
};
}
}
const router = express.Router();