mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Implement TogetherAI as vectorization provider
This commit is contained in:
@@ -12,6 +12,7 @@ const settings = {
|
||||
// For both
|
||||
source: 'transformers',
|
||||
include_wi: false,
|
||||
model: 'togethercomputer/m2-bert-80M-32k-retrieval',
|
||||
|
||||
// For chats
|
||||
enabled_chats: false,
|
||||
@@ -440,6 +441,17 @@ function addExtrasHeaders(headers) {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Add headers for the Extras API source.
|
||||
* @param {object} headers Headers object
|
||||
*/
|
||||
function addTogetherAiHeaders(headers) {
|
||||
Object.assign(headers, {
|
||||
'X-Togetherai-Model': extension_settings.vectors.model,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Inserts vector items into a collection
|
||||
* @param {string} collectionId - The collection to insert into
|
||||
@@ -449,7 +461,8 @@ function addExtrasHeaders(headers) {
|
||||
async function insertVectorItems(collectionId, items) {
|
||||
if (settings.source === 'openai' && !secret_state[SECRET_KEYS.OPENAI] ||
|
||||
settings.source === 'palm' && !secret_state[SECRET_KEYS.MAKERSUITE] ||
|
||||
settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI]) {
|
||||
settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI] ||
|
||||
settings.source === 'togetherai' && !secret_state[SECRET_KEYS.TOGETHERAI]) {
|
||||
throw new Error('Vectors: API key missing', { cause: 'api_key_missing' });
|
||||
}
|
||||
|
||||
@@ -460,6 +473,8 @@ async function insertVectorItems(collectionId, items) {
|
||||
const headers = getRequestHeaders();
|
||||
if (settings.source === 'extras') {
|
||||
addExtrasHeaders(headers);
|
||||
} else if (settings.source === 'togetherai') {
|
||||
addTogetherAiHeaders(headers);
|
||||
}
|
||||
|
||||
const response = await fetch('/api/vector/insert', {
|
||||
@@ -509,6 +524,8 @@ async function queryCollection(collectionId, searchText, topK) {
|
||||
const headers = getRequestHeaders();
|
||||
if (settings.source === 'extras') {
|
||||
addExtrasHeaders(headers);
|
||||
} else if (settings.source === 'togetherai') {
|
||||
addTogetherAiHeaders(headers);
|
||||
}
|
||||
|
||||
const response = await fetch('/api/vector/query', {
|
||||
@@ -526,8 +543,7 @@ async function queryCollection(collectionId, searchText, topK) {
|
||||
throw new Error(`Failed to query collection ${collectionId}`);
|
||||
}
|
||||
|
||||
const results = await response.json();
|
||||
return results;
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -617,6 +633,13 @@ jQuery(async () => {
|
||||
}
|
||||
|
||||
Object.assign(settings, extension_settings.vectors);
|
||||
|
||||
if (settings.source === 'togetherai') {
|
||||
$('#vectorsModel').show();
|
||||
} else {
|
||||
$('#vectorsModel').hide();
|
||||
}
|
||||
|
||||
// Migrate from TensorFlow to Transformers
|
||||
settings.source = settings.source !== 'local' ? settings.source : 'transformers';
|
||||
$('#extensions_settings2').append(renderExtensionTemplate(MODULE_NAME, 'settings'));
|
||||
@@ -634,6 +657,17 @@ jQuery(async () => {
|
||||
});
|
||||
$('#vectors_source').val(settings.source).on('change', () => {
|
||||
settings.source = String($('#vectors_source').val());
|
||||
if (settings.source === 'togetherai') {
|
||||
$('#vectorsModel').show();
|
||||
} else {
|
||||
$('#vectorsModel').hide();
|
||||
}
|
||||
Object.assign(extension_settings.vectors, settings);
|
||||
saveSettingsDebounced();
|
||||
});
|
||||
|
||||
$('#vectors_model').val(settings.model).on('change', () => {
|
||||
settings.model = String($('#vectors_model').val());
|
||||
Object.assign(extension_settings.vectors, settings);
|
||||
saveSettingsDebounced();
|
||||
});
|
||||
|
@@ -15,6 +15,22 @@
|
||||
<option value="openai">OpenAI</option>
|
||||
<option value="palm">Google MakerSuite (PaLM)</option>
|
||||
<option value="mistral">MistralAI</option>
|
||||
<option value="togetherai">TogetherAI</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="flex-container flexFlowColumn" id="vectorsModel">
|
||||
<label for="vectors_model">
|
||||
Vectorization Model
|
||||
</label>
|
||||
<select id="vectors_model" class="text_pole">
|
||||
<option value="togethercomputer/m2-bert-80M-32k-retrieval">M2-BERT-Retrieval-32k</option>
|
||||
<option value="togethercomputer/m2-bert-80M-8k-retrieval">M2-BERT-Retrieval-8k</option>
|
||||
<option value="togethercomputer/m2-bert-80M-2k-retrieval">M2-BERT-Retrieval-2K</option>
|
||||
<option value="WhereIsAI/UAE-Large-V1">UAE-Large-V1</option>
|
||||
<option value="BAAI/bge-large-en-v1.5">BAAI-Bge-Large-1p5</option>
|
||||
<option value="BAAI/bge-base-en-v1.5">BAAI-Bge-Base-1p5</option>
|
||||
<option value="sentence-transformers/msmarco-bert-base-dot-v5">Sentence-BERT</option>
|
||||
<option value="bert-base-uncased">Bert Base Uncased</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
|
@@ -5,7 +5,7 @@ const sanitize = require('sanitize-filename');
|
||||
const { jsonParser } = require('../express-common');
|
||||
|
||||
// Don't forget to add new sources to the SOURCES array
|
||||
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm'];
|
||||
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togetherai'];
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from the given source.
|
||||
@@ -16,9 +16,10 @@ const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm'];
|
||||
*/
|
||||
async function getVector(source, sourceSettings, text) {
|
||||
switch (source) {
|
||||
case 'togetherai':
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
return require('../openai-vectors').getOpenAIVector(text, source);
|
||||
return require('../openai-vectors').getOpenAIVector(text, source, sourceSettings.model);
|
||||
case 'transformers':
|
||||
return require('../embedding').getTransformersVector(text);
|
||||
case 'extras':
|
||||
@@ -44,9 +45,10 @@ async function getBatchVector(source, sourceSettings, texts) {
|
||||
let results = [];
|
||||
for (let batch of batches) {
|
||||
switch (source) {
|
||||
case 'togetherai':
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source));
|
||||
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, sourceSettings.model));
|
||||
break;
|
||||
case 'transformers':
|
||||
results.push(...await require('../embedding').getTransformersBatchVector(batch));
|
||||
@@ -165,6 +167,13 @@ async function queryCollection(collectionId, source, sourceSettings, searchText,
|
||||
* @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
|
||||
*/
|
||||
function getSourceSettings(source, request) {
|
||||
if (source === 'togetherai') {
|
||||
let model = String(request.headers['x-togetherai-model']);
|
||||
|
||||
return {
|
||||
model: model,
|
||||
};
|
||||
} else {
|
||||
// Extras API settings to connect to the Extras embeddings provider
|
||||
let extrasUrl = '';
|
||||
let extrasKey = '';
|
||||
@@ -173,11 +182,11 @@ function getSourceSettings(source, request) {
|
||||
extrasKey = String(request.headers['x-extras-key']);
|
||||
}
|
||||
|
||||
const sourceSettings = {
|
||||
return {
|
||||
extrasUrl: extrasUrl,
|
||||
extrasKey: extrasKey,
|
||||
};
|
||||
return sourceSettings;
|
||||
}
|
||||
}
|
||||
|
||||
const router = express.Router();
|
||||
|
@@ -2,6 +2,11 @@ const fetch = require('node-fetch').default;
|
||||
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
|
||||
const SOURCES = {
|
||||
'togetherai': {
|
||||
secretKey: SECRET_KEYS.TOGETHERAI,
|
||||
url: 'api.together.xyz',
|
||||
model: 'togethercomputer/m2-bert-80M-32k-retrieval',
|
||||
},
|
||||
'mistral': {
|
||||
secretKey: SECRET_KEYS.MISTRALAI,
|
||||
url: 'api.mistral.ai',
|
||||
@@ -18,9 +23,10 @@ const SOURCES = {
|
||||
* Gets the vector for the given text batch from an OpenAI compatible endpoint.
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @param model
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getOpenAIBatchVector(texts, source) {
|
||||
async function getOpenAIBatchVector(texts, source, model = '') {
|
||||
const config = SOURCES[source];
|
||||
|
||||
if (!config) {
|
||||
@@ -44,7 +50,7 @@ async function getOpenAIBatchVector(texts, source) {
|
||||
},
|
||||
body: JSON.stringify({
|
||||
input: texts,
|
||||
model: config.model,
|
||||
model: source === 'togetherai' && model !== '' ? model : config.model,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -72,10 +78,11 @@ async function getOpenAIBatchVector(texts, source) {
|
||||
* Gets the vector for the given text from an OpenAI compatible endpoint.
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @param model
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getOpenAIVector(text, source) {
|
||||
const vectors = await getOpenAIBatchVector([text], source);
|
||||
async function getOpenAIVector(text, source, model = '') {
|
||||
const vectors = await getOpenAIBatchVector([text], source, model);
|
||||
return vectors[0];
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user