mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Add Cohere as embedding source
This commit is contained in:
@@ -12,23 +12,26 @@ const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togethe
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {boolean} isQuery - If the text is a query for embedding search
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getVector(source, sourceSettings, text, directories) {
|
||||
async function getVector(source, sourceSettings, text, isQuery, directories) {
|
||||
switch (source) {
|
||||
case 'nomicai':
|
||||
return require('../nomicai-vectors').getNomicAIVector(text, source, directories);
|
||||
return require('../vectors/nomicai-vectors').getNomicAIVector(text, source, directories);
|
||||
case 'togetherai':
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
return require('../openai-vectors').getOpenAIVector(text, source, directories, sourceSettings.model);
|
||||
return require('../vectors/openai-vectors').getOpenAIVector(text, source, directories, sourceSettings.model);
|
||||
case 'transformers':
|
||||
return require('../embedding').getTransformersVector(text);
|
||||
return require('../vectors/embedding').getTransformersVector(text);
|
||||
case 'extras':
|
||||
return require('../extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey);
|
||||
return require('../vectors/extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey);
|
||||
case 'palm':
|
||||
return require('../makersuite-vectors').getMakerSuiteVector(text, directories);
|
||||
return require('../vectors/makersuite-vectors').getMakerSuiteVector(text, directories);
|
||||
case 'cohere':
|
||||
return require('../vectors/cohere-vectors').getCohereVector(text, isQuery, directories, sourceSettings.model);
|
||||
}
|
||||
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
@@ -39,10 +42,11 @@ async function getVector(source, sourceSettings, text, directories) {
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @param {boolean} isQuery - If the text is a query for embedding search
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getBatchVector(source, sourceSettings, texts, directories) {
|
||||
async function getBatchVector(source, sourceSettings, texts, isQuery, directories) {
|
||||
const batchSize = 10;
|
||||
const batches = Array(Math.ceil(texts.length / batchSize)).fill(undefined).map((_, i) => texts.slice(i * batchSize, i * batchSize + batchSize));
|
||||
|
||||
@@ -50,21 +54,24 @@ async function getBatchVector(source, sourceSettings, texts, directories) {
|
||||
for (let batch of batches) {
|
||||
switch (source) {
|
||||
case 'nomicai':
|
||||
results.push(...await require('../nomicai-vectors').getNomicAIBatchVector(batch, source, directories));
|
||||
results.push(...await require('../vectors/nomicai-vectors').getNomicAIBatchVector(batch, source, directories));
|
||||
break;
|
||||
case 'togetherai':
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, directories, sourceSettings.model));
|
||||
results.push(...await require('../vectors/openai-vectors').getOpenAIBatchVector(batch, source, directories, sourceSettings.model));
|
||||
break;
|
||||
case 'transformers':
|
||||
results.push(...await require('../embedding').getTransformersBatchVector(batch));
|
||||
results.push(...await require('../vectors/embedding').getTransformersBatchVector(batch));
|
||||
break;
|
||||
case 'extras':
|
||||
results.push(...await require('../extras-vectors').getExtrasBatchVector(batch, sourceSettings.extrasUrl, sourceSettings.extrasKey));
|
||||
results.push(...await require('../vectors/extras-vectors').getExtrasBatchVector(batch, sourceSettings.extrasUrl, sourceSettings.extrasKey));
|
||||
break;
|
||||
case 'palm':
|
||||
results.push(...await require('../makersuite-vectors').getMakerSuiteBatchVector(batch, directories));
|
||||
results.push(...await require('../vectors/makersuite-vectors').getMakerSuiteBatchVector(batch, directories));
|
||||
break;
|
||||
case 'cohere':
|
||||
results.push(...await require('../vectors/cohere-vectors').getCohereBatchVector(batch, isQuery, directories, sourceSettings.model));
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
@@ -106,7 +113,7 @@ async function insertVectorItems(directories, collectionId, source, sourceSettin
|
||||
|
||||
await store.beginUpdate();
|
||||
|
||||
const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text), directories);
|
||||
const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text), false, directories);
|
||||
|
||||
for (let i = 0; i < items.length; i++) {
|
||||
const item = items[i];
|
||||
@@ -165,7 +172,7 @@ async function deleteVectorItems(directories, collectionId, source, hashes) {
|
||||
*/
|
||||
async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK) {
|
||||
const store = await getIndex(directories, collectionId, source);
|
||||
const vector = await getVector(source, sourceSettings, searchText, directories);
|
||||
const vector = await getVector(source, sourceSettings, searchText, true, directories);
|
||||
|
||||
const result = await store.queryItems(vector, topK);
|
||||
const metadata = result.map(x => x.item.metadata);
|
||||
@@ -184,7 +191,7 @@ async function queryCollection(directories, collectionId, source, sourceSettings
|
||||
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - The top K results from each collection
|
||||
*/
|
||||
async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK) {
|
||||
const vector = await getVector(source, sourceSettings, searchText, directories);
|
||||
const vector = await getVector(source, sourceSettings, searchText, true, directories);
|
||||
const results = [];
|
||||
|
||||
for (const collectionId of collectionIds) {
|
||||
@@ -223,18 +230,24 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe
|
||||
*/
|
||||
function getSourceSettings(source, request) {
|
||||
if (source === 'togetherai') {
|
||||
let model = String(request.headers['x-togetherai-model']);
|
||||
const model = String(request.headers['x-togetherai-model']);
|
||||
|
||||
return {
|
||||
model: model,
|
||||
};
|
||||
} else if (source === 'openai') {
|
||||
let model = String(request.headers['x-openai-model']);
|
||||
const model = String(request.headers['x-openai-model']);
|
||||
|
||||
return {
|
||||
model: model,
|
||||
};
|
||||
} else {
|
||||
} else if (source === 'cohere') {
|
||||
const model = String(request.headers['x-cohere-model']);
|
||||
|
||||
return {
|
||||
model: model,
|
||||
};
|
||||
}else {
|
||||
// Extras API settings to connect to the Extras embeddings provider
|
||||
let extrasUrl = '';
|
||||
let extrasKey = '';
|
||||
|
65
src/vectors/cohere-vectors.js
Normal file
65
src/vectors/cohere-vectors.js
Normal file
@@ -0,0 +1,65 @@
|
||||
const fetch = require('node-fetch').default;
|
||||
const { SECRET_KEYS, readSecret } = require('../endpoints/secrets');
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text batch from an OpenAI compatible endpoint.
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @param {boolean} isQuery - If the text is a query for embedding search
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @param {string} model - The model to use for the embedding
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getCohereBatchVector(texts, isQuery, directories, model) {
|
||||
const key = readSecret(directories, SECRET_KEYS.COHERE);
|
||||
|
||||
if (!key) {
|
||||
console.log('No API key found');
|
||||
throw new Error('No API key found');
|
||||
}
|
||||
|
||||
const response = await fetch('https://api.cohere.ai/v1/embed', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${key}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
texts: texts,
|
||||
model: model,
|
||||
input_type: isQuery ? 'search_query' : 'search_document',
|
||||
truncate: 'END',
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
console.log('API request failed', response.statusText, text);
|
||||
throw new Error('API request failed');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
if (!Array.isArray(data?.embeddings)) {
|
||||
console.log('API response was not an array');
|
||||
throw new Error('API response was not an array');
|
||||
}
|
||||
|
||||
return data.embeddings;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from an OpenAI compatible endpoint.
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {boolean} isQuery - If the text is a query for embedding search
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @param {string} model - The model to use for the embedding
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getCohereVector(text, isQuery, directories, model) {
|
||||
const vectors = await getCohereBatchVector([text], isQuery, directories, model);
|
||||
return vectors[0];
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getCohereBatchVector,
|
||||
getCohereVector,
|
||||
};
|
@@ -6,7 +6,7 @@ const TASK = 'feature-extraction';
|
||||
* @returns {Promise<number[]>} - The vectorized text in form of an array of numbers
|
||||
*/
|
||||
async function getTransformersVector(text) {
|
||||
const module = await import('./transformers.mjs');
|
||||
const module = await import('../transformers.mjs');
|
||||
const pipe = await module.default.getPipeline(TASK);
|
||||
const result = await pipe(text, { pooling: 'mean', normalize: true });
|
||||
const vector = Array.from(result.data);
|
@@ -1,10 +1,10 @@
|
||||
const fetch = require('node-fetch').default;
|
||||
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
const { SECRET_KEYS, readSecret } = require('../endpoints/secrets');
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from gecko model
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @param {import('./users').UserDirectoryList} directories - The directories object for the user
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getMakerSuiteBatchVector(texts, directories) {
|
||||
@@ -16,7 +16,7 @@ async function getMakerSuiteBatchVector(texts, directories) {
|
||||
/**
|
||||
* Gets the vector for the given text from PaLM gecko model
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {import('./users').UserDirectoryList} directories - The directories object for the user
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getMakerSuiteVector(text, directories) {
|
@@ -1,5 +1,5 @@
|
||||
const fetch = require('node-fetch').default;
|
||||
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
const { SECRET_KEYS, readSecret } = require('../endpoints/secrets');
|
||||
|
||||
const SOURCES = {
|
||||
'nomicai': {
|
||||
@@ -13,7 +13,7 @@ const SOURCES = {
|
||||
* Gets the vector for the given text batch from an OpenAI compatible endpoint.
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {import('./users').UserDirectoryList} directories - The directories object for the user
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getNomicAIBatchVector(texts, source, directories) {
|
||||
@@ -64,7 +64,7 @@ async function getNomicAIBatchVector(texts, source, directories) {
|
||||
* Gets the vector for the given text from an OpenAI compatible endpoint.
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {import('./users').UserDirectoryList} directories - The directories object for the user
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getNomicAIVector(text, source, directories) {
|
@@ -1,5 +1,5 @@
|
||||
const fetch = require('node-fetch').default;
|
||||
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
const { SECRET_KEYS, readSecret } = require('../endpoints/secrets');
|
||||
|
||||
const SOURCES = {
|
||||
'togetherai': {
|
||||
@@ -23,7 +23,7 @@ const SOURCES = {
|
||||
* Gets the vector for the given text batch from an OpenAI compatible endpoint.
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {import('./users').UserDirectoryList} directories - The directories object for the user
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @param {string} model - The model to use for the embedding
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
@@ -79,7 +79,7 @@ async function getOpenAIBatchVector(texts, source, directories, model = '') {
|
||||
* Gets the vector for the given text from an OpenAI compatible endpoint.
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {import('./users').UserDirectoryList} directories - The directories object for the user
|
||||
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
|
||||
* @param {string} model - The model to use for the embedding
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
Reference in New Issue
Block a user