Split vector batches into 10s

This commit is contained in:
Cohee
2024-02-01 11:02:47 +02:00
parent 7618133d6d
commit 695b438c0d

View File

@@ -35,19 +35,31 @@ async function getVector(source, sourceSettings, text) {
* @returns {Promise<number[][]>} - The array of vectors for the texts
*/
async function getBatchVector(source, sourceSettings, texts) {
    // Embeds `texts` in chunks of `batchSize` and returns all resulting vectors
    // concatenated in input order.
    const batchSize = 10;

    // Resolve the backend once, before touching the input, so that an unknown
    // source is rejected even when `texts` is empty (a per-batch check would be
    // skipped entirely because the loop below would never execute).
    // The `require` calls stay inside the closures so a backend module is only
    // loaded when a batch is actually sent to it.
    let embedBatch;
    switch (source) {
        case 'mistral':
        case 'openai':
            embedBatch = (batch) => require('../openai-vectors').getOpenAIBatchVector(batch, source);
            break;
        case 'transformers':
            embedBatch = (batch) => require('../embedding').getTransformersBatchVector(batch);
            break;
        case 'extras':
            embedBatch = (batch) => require('../extras-vectors').getExtrasBatchVector(batch, sourceSettings.extrasUrl, sourceSettings.extrasKey);
            break;
        case 'palm':
            embedBatch = (batch) => require('../makersuite-vectors').getMakerSuiteBatchVector(batch);
            break;
        default:
            throw new Error(`Unknown vector source ${source}`);
    }

    const results = [];
    // Batches are awaited one at a time (not in parallel); each batch's vectors
    // are appended in order, so results[i] corresponds to texts[i].
    for (let start = 0; start < texts.length; start += batchSize) {
        const batch = texts.slice(start, start + batchSize);
        results.push(...await embedBatch(batch));
    }

    return results;
}
/** /**