#1671 Add batch vectorization

This commit is contained in:
Cohee 2024-01-24 13:56:13 +02:00
parent 3d2c8bf674
commit cfdf43a26e
4 changed files with 77 additions and 15 deletions

View File

@ -1,6 +1,7 @@
const TASK = 'feature-extraction';
/**
* Gets the vectorized text in form of an array of numbers.
* @param {string} text - The text to vectorize
* @returns {Promise<number[]>} - The vectorized text in form of an array of numbers
*/
@ -12,6 +13,20 @@ async function getTransformersVector(text) {
return vector;
}
/**
* Gets the vectorized texts in form of an array of arrays of numbers.
* @param {string[]} texts - The texts to vectorize
* @returns {Promise<number[][]>} - The vectorized texts in form of an array of arrays of numbers
*/
async function getTransformersBatchVector(texts) {
const result = [];
for (const text of texts) {
result.push(await getTransformersVector(text));
}
return result;
}
module.exports = {
getTransformersVector,
getTransformersBatchVector,
};

View File

@ -24,6 +24,26 @@ async function getVector(source, text) {
throw new Error(`Unknown vector source ${source}`);
}
/**
* Gets the vector for the given text batch from the given source.
* @param {string} source - The source of the vector
* @param {string[]} texts - The array of texts to get the vector for
* @returns {Promise<number[][]>} - The array of vectors for the texts
*/
async function getBatchVector(source, texts) {
switch (source) {
case 'mistral':
case 'openai':
return require('../openai-vectors').getOpenAIBatchVector(texts, source);
case 'transformers':
return require('../embedding').getTransformersBatchVector(texts);
case 'palm':
return require('../makersuite-vectors').getMakerSuiteBatchVector(texts);
}
throw new Error(`Unknown vector source ${source}`);
}
/**
* Gets the index for the vector collection
* @param {string} collectionId - The collection ID
@ -52,12 +72,12 @@ async function insertVectorItems(collectionId, source, items) {
await store.beginUpdate();
for (const item of items) {
const text = item.text;
const hash = item.hash;
const index = item.index;
const vector = await getVector(source, text);
await store.upsertItem({ vector: vector, metadata: { hash, text, index } });
const vectors = await getBatchVector(source, items.map(x => x.text));
for (let i = 0; i < items.length; i++) {
const item = items[i];
const vector = vectors[i];
await store.upsertItem({ vector: vector, metadata: { hash: item.hash, text: item.text, index: item.index } });
}
await store.endUpdate();

View File

@ -1,6 +1,17 @@
const fetch = require('node-fetch').default;
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
/**
* Gets the vector for the given text from gecko model
* @param {string[]} texts - The array of texts to get the vector for
* @returns {Promise<number[][]>} - The array of vectors for the texts
*/
async function getMakerSuiteBatchVector(texts) {
const promises = texts.map(text => getMakerSuiteVector(text));
const vectors = await Promise.all(promises);
return vectors;
}
/**
* Gets the vector for the given text from PaLM gecko model
* @param {string} text - The text to get the vector for
@ -40,4 +51,5 @@ async function getMakerSuiteVector(text) {
module.exports = {
getMakerSuiteVector,
getMakerSuiteBatchVector,
};

View File

@ -3,7 +3,7 @@ const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
const SOURCES = {
'mistral': {
secretKey: SECRET_KEYS.MISTRAL,
secretKey: SECRET_KEYS.MISTRALAI,
url: 'api.mistral.ai',
model: 'mistral-embed',
},
@ -15,12 +15,12 @@ const SOURCES = {
};
/**
* Gets the vector for the given text from an OpenAI compatible endpoint.
* @param {string} text - The text to get the vector for
* Gets the vector for the given text batch from an OpenAI compatible endpoint.
* @param {string[]} texts - The array of texts to get the vector for
* @param {string} source - The source of the vector
* @returns {Promise<number[]>} - The vector for the text
* @returns {Promise<number[][]>} - The array of vectors for the texts
*/
async function getOpenAIVector(text, source) {
async function getOpenAIBatchVector(texts, source) {
const config = SOURCES[source];
if (!config) {
@ -43,7 +43,7 @@ async function getOpenAIVector(text, source) {
Authorization: `Bearer ${key}`,
},
body: JSON.stringify({
input: text,
input: texts,
model: config.model,
}),
});
@ -55,16 +55,31 @@ async function getOpenAIVector(text, source) {
}
const data = await response.json();
const vector = data?.data[0]?.embedding;
if (!Array.isArray(vector)) {
if (!Array.isArray(data?.data)) {
console.log('API response was not an array');
throw new Error('API response was not an array');
}
return vector;
// Sort data by x.index to ensure the order is correct
data.data.sort((a, b) => a.index - b.index);
const vectors = data.data.map(x => x.embedding);
return vectors;
}
/**
* Gets the vector for the given text from an OpenAI compatible endpoint.
* @param {string} text - The text to get the vector for
* @param {string} source - The source of the vector
* @returns {Promise<number[]>} - The vector for the text
*/
async function getOpenAIVector(text, source) {
const vectors = await getOpenAIBatchVector([text], source);
return vectors[0];
}
module.exports = {
getOpenAIVector,
getOpenAIBatchVector,
};