mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2024-12-12 09:26:33 +01:00
#1671 Add batch vectorization
This commit is contained in:
parent
3d2c8bf674
commit
cfdf43a26e
@ -1,6 +1,7 @@
|
||||
const TASK = 'feature-extraction';
|
||||
|
||||
/**
|
||||
* Gets the vectorized text in form of an array of numbers.
|
||||
* @param {string} text - The text to vectorize
|
||||
* @returns {Promise<number[]>} - The vectorized text in form of an array of numbers
|
||||
*/
|
||||
@ -12,6 +13,20 @@ async function getTransformersVector(text) {
|
||||
return vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vectorized texts in form of an array of arrays of numbers.
|
||||
* @param {string[]} texts - The texts to vectorize
|
||||
* @returns {Promise<number[][]>} - The vectorized texts in form of an array of arrays of numbers
|
||||
*/
|
||||
async function getTransformersBatchVector(texts) {
|
||||
const result = [];
|
||||
for (const text of texts) {
|
||||
result.push(await getTransformersVector(text));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getTransformersVector,
|
||||
getTransformersBatchVector,
|
||||
};
|
||||
|
@ -24,6 +24,26 @@ async function getVector(source, text) {
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text batch from the given source.
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getBatchVector(source, texts) {
|
||||
switch (source) {
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
return require('../openai-vectors').getOpenAIBatchVector(texts, source);
|
||||
case 'transformers':
|
||||
return require('../embedding').getTransformersBatchVector(texts);
|
||||
case 'palm':
|
||||
return require('../makersuite-vectors').getMakerSuiteBatchVector(texts);
|
||||
}
|
||||
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the index for the vector collection
|
||||
* @param {string} collectionId - The collection ID
|
||||
@ -52,12 +72,12 @@ async function insertVectorItems(collectionId, source, items) {
|
||||
|
||||
await store.beginUpdate();
|
||||
|
||||
for (const item of items) {
|
||||
const text = item.text;
|
||||
const hash = item.hash;
|
||||
const index = item.index;
|
||||
const vector = await getVector(source, text);
|
||||
await store.upsertItem({ vector: vector, metadata: { hash, text, index } });
|
||||
const vectors = await getBatchVector(source, items.map(x => x.text));
|
||||
|
||||
for (let i = 0; i < items.length; i++) {
|
||||
const item = items[i];
|
||||
const vector = vectors[i];
|
||||
await store.upsertItem({ vector: vector, metadata: { hash: item.hash, text: item.text, index: item.index } });
|
||||
}
|
||||
|
||||
await store.endUpdate();
|
||||
|
@ -1,6 +1,17 @@
|
||||
const fetch = require('node-fetch').default;
|
||||
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from gecko model
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getMakerSuiteBatchVector(texts) {
|
||||
const promises = texts.map(text => getMakerSuiteVector(text));
|
||||
const vectors = await Promise.all(promises);
|
||||
return vectors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from PaLM gecko model
|
||||
* @param {string} text - The text to get the vector for
|
||||
@ -40,4 +51,5 @@ async function getMakerSuiteVector(text) {
|
||||
|
||||
module.exports = {
|
||||
getMakerSuiteVector,
|
||||
getMakerSuiteBatchVector,
|
||||
};
|
||||
|
@ -3,7 +3,7 @@ const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
|
||||
const SOURCES = {
|
||||
'mistral': {
|
||||
secretKey: SECRET_KEYS.MISTRAL,
|
||||
secretKey: SECRET_KEYS.MISTRALAI,
|
||||
url: 'api.mistral.ai',
|
||||
model: 'mistral-embed',
|
||||
},
|
||||
@ -15,12 +15,12 @@ const SOURCES = {
|
||||
};
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from an OpenAI compatible endpoint.
|
||||
* @param {string} text - The text to get the vector for
|
||||
* Gets the vector for the given text batch from an OpenAI compatible endpoint.
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getOpenAIVector(text, source) {
|
||||
async function getOpenAIBatchVector(texts, source) {
|
||||
const config = SOURCES[source];
|
||||
|
||||
if (!config) {
|
||||
@ -43,7 +43,7 @@ async function getOpenAIVector(text, source) {
|
||||
Authorization: `Bearer ${key}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
input: text,
|
||||
input: texts,
|
||||
model: config.model,
|
||||
}),
|
||||
});
|
||||
@ -55,16 +55,31 @@ async function getOpenAIVector(text, source) {
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const vector = data?.data[0]?.embedding;
|
||||
|
||||
if (!Array.isArray(vector)) {
|
||||
if (!Array.isArray(data?.data)) {
|
||||
console.log('API response was not an array');
|
||||
throw new Error('API response was not an array');
|
||||
}
|
||||
|
||||
return vector;
|
||||
// Sort data by x.index to ensure the order is correct
|
||||
data.data.sort((a, b) => a.index - b.index);
|
||||
|
||||
const vectors = data.data.map(x => x.embedding);
|
||||
return vectors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from an OpenAI compatible endpoint.
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getOpenAIVector(text, source) {
|
||||
const vectors = await getOpenAIBatchVector([text], source);
|
||||
return vectors[0];
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getOpenAIVector,
|
||||
getOpenAIBatchVector,
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user