Add Cohere as embedding source

This commit is contained in:
Cohee
2024-04-19 00:07:12 +03:00
parent b69493d252
commit 25cb598694
9 changed files with 147 additions and 32 deletions

View File

@@ -12,23 +12,26 @@ const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togethe
* @param {string} source - The source of the vector
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {string} text - The text to get the vector for
* @param {boolean} isQuery - If the text is a query for embedding search
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
* @returns {Promise<number[]>} - The vector for the text
*/
async function getVector(source, sourceSettings, text, directories) {
async function getVector(source, sourceSettings, text, isQuery, directories) {
switch (source) {
case 'nomicai':
return require('../nomicai-vectors').getNomicAIVector(text, source, directories);
return require('../vectors/nomicai-vectors').getNomicAIVector(text, source, directories);
case 'togetherai':
case 'mistral':
case 'openai':
return require('../openai-vectors').getOpenAIVector(text, source, directories, sourceSettings.model);
return require('../vectors/openai-vectors').getOpenAIVector(text, source, directories, sourceSettings.model);
case 'transformers':
return require('../embedding').getTransformersVector(text);
return require('../vectors/embedding').getTransformersVector(text);
case 'extras':
return require('../extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey);
return require('../vectors/extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey);
case 'palm':
return require('../makersuite-vectors').getMakerSuiteVector(text, directories);
return require('../vectors/makersuite-vectors').getMakerSuiteVector(text, directories);
case 'cohere':
return require('../vectors/cohere-vectors').getCohereVector(text, isQuery, directories, sourceSettings.model);
}
throw new Error(`Unknown vector source ${source}`);
@@ -39,10 +42,11 @@ async function getVector(source, sourceSettings, text, directories) {
* @param {string} source - The source of the vector
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {string[]} texts - The array of texts to get the vector for
* @param {boolean} isQuery - If the text is a query for embedding search
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
* @returns {Promise<number[][]>} - The array of vectors for the texts
*/
async function getBatchVector(source, sourceSettings, texts, directories) {
async function getBatchVector(source, sourceSettings, texts, isQuery, directories) {
const batchSize = 10;
const batches = Array(Math.ceil(texts.length / batchSize)).fill(undefined).map((_, i) => texts.slice(i * batchSize, i * batchSize + batchSize));
@@ -50,21 +54,24 @@ async function getBatchVector(source, sourceSettings, texts, directories) {
for (let batch of batches) {
switch (source) {
case 'nomicai':
results.push(...await require('../nomicai-vectors').getNomicAIBatchVector(batch, source, directories));
results.push(...await require('../vectors/nomicai-vectors').getNomicAIBatchVector(batch, source, directories));
break;
case 'togetherai':
case 'mistral':
case 'openai':
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, directories, sourceSettings.model));
results.push(...await require('../vectors/openai-vectors').getOpenAIBatchVector(batch, source, directories, sourceSettings.model));
break;
case 'transformers':
results.push(...await require('../embedding').getTransformersBatchVector(batch));
results.push(...await require('../vectors/embedding').getTransformersBatchVector(batch));
break;
case 'extras':
results.push(...await require('../extras-vectors').getExtrasBatchVector(batch, sourceSettings.extrasUrl, sourceSettings.extrasKey));
results.push(...await require('../vectors/extras-vectors').getExtrasBatchVector(batch, sourceSettings.extrasUrl, sourceSettings.extrasKey));
break;
case 'palm':
results.push(...await require('../makersuite-vectors').getMakerSuiteBatchVector(batch, directories));
results.push(...await require('../vectors/makersuite-vectors').getMakerSuiteBatchVector(batch, directories));
break;
case 'cohere':
results.push(...await require('../vectors/cohere-vectors').getCohereBatchVector(batch, isQuery, directories, sourceSettings.model));
break;
default:
throw new Error(`Unknown vector source ${source}`);
@@ -106,7 +113,7 @@ async function insertVectorItems(directories, collectionId, source, sourceSettin
await store.beginUpdate();
const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text), directories);
const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text), false, directories);
for (let i = 0; i < items.length; i++) {
const item = items[i];
@@ -165,7 +172,7 @@ async function deleteVectorItems(directories, collectionId, source, hashes) {
*/
async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK) {
const store = await getIndex(directories, collectionId, source);
const vector = await getVector(source, sourceSettings, searchText, directories);
const vector = await getVector(source, sourceSettings, searchText, true, directories);
const result = await store.queryItems(vector, topK);
const metadata = result.map(x => x.item.metadata);
@@ -184,7 +191,7 @@ async function queryCollection(directories, collectionId, source, sourceSettings
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - The top K results from each collection
*/
async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK) {
const vector = await getVector(source, sourceSettings, searchText, directories);
const vector = await getVector(source, sourceSettings, searchText, true, directories);
const results = [];
for (const collectionId of collectionIds) {
@@ -223,18 +230,24 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe
*/
function getSourceSettings(source, request) {
if (source === 'togetherai') {
let model = String(request.headers['x-togetherai-model']);
const model = String(request.headers['x-togetherai-model']);
return {
model: model,
};
} else if (source === 'openai') {
let model = String(request.headers['x-openai-model']);
const model = String(request.headers['x-openai-model']);
return {
model: model,
};
} else {
} else if (source === 'cohere') {
const model = String(request.headers['x-cohere-model']);
return {
model: model,
};
}else {
// Extras API settings to connect to the Extras embeddings provider
let extrasUrl = '';
let extrasKey = '';