mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-02-20 22:20:39 +01:00
Merge pull request #1879 from Dakraid/togetherai-vectorization-source
Implement TogetherAI as vectorization provider
This commit is contained in:
commit
181657cede
@ -12,6 +12,7 @@ const settings = {
|
||||
// For both
|
||||
source: 'transformers',
|
||||
include_wi: false,
|
||||
togetherai_model: 'togethercomputer/m2-bert-80M-32k-retrieval',
|
||||
|
||||
// For chats
|
||||
enabled_chats: false,
|
||||
@ -440,6 +441,16 @@ function addExtrasHeaders(headers) {
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Adds headers for the TogetherAI vectorization source.
 * (Previous doc comment was copy-pasted from addExtrasHeaders and
 * incorrectly referred to the Extras API.)
 * @param {object} headers Headers object to mutate in place
 */
function addTogetherAiHeaders(headers) {
    Object.assign(headers, {
        // Server reads this as request.headers['x-togetherai-model']
        'X-Togetherai-Model': extension_settings.vectors.togetherai_model,
    });
}
|
||||
|
||||
/**
|
||||
* Inserts vector items into a collection
|
||||
* @param {string} collectionId - The collection to insert into
|
||||
@ -449,7 +460,8 @@ function addExtrasHeaders(headers) {
|
||||
async function insertVectorItems(collectionId, items) {
|
||||
if (settings.source === 'openai' && !secret_state[SECRET_KEYS.OPENAI] ||
|
||||
settings.source === 'palm' && !secret_state[SECRET_KEYS.MAKERSUITE] ||
|
||||
settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI]) {
|
||||
settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI] ||
|
||||
settings.source === 'togetherai' && !secret_state[SECRET_KEYS.TOGETHERAI]) {
|
||||
throw new Error('Vectors: API key missing', { cause: 'api_key_missing' });
|
||||
}
|
||||
|
||||
@ -460,6 +472,8 @@ async function insertVectorItems(collectionId, items) {
|
||||
const headers = getRequestHeaders();
|
||||
if (settings.source === 'extras') {
|
||||
addExtrasHeaders(headers);
|
||||
} else if (settings.source === 'togetherai') {
|
||||
addTogetherAiHeaders(headers);
|
||||
}
|
||||
|
||||
const response = await fetch('/api/vector/insert', {
|
||||
@ -509,6 +523,8 @@ async function queryCollection(collectionId, searchText, topK) {
|
||||
const headers = getRequestHeaders();
|
||||
if (settings.source === 'extras') {
|
||||
addExtrasHeaders(headers);
|
||||
} else if (settings.source === 'togetherai') {
|
||||
addTogetherAiHeaders(headers);
|
||||
}
|
||||
|
||||
const response = await fetch('/api/vector/query', {
|
||||
@ -526,8 +542,7 @@ async function queryCollection(collectionId, searchText, topK) {
|
||||
throw new Error(`Failed to query collection ${collectionId}`);
|
||||
}
|
||||
|
||||
const results = await response.json();
|
||||
return results;
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -564,6 +579,7 @@ async function purgeVectorIndex(collectionId) {
|
||||
function toggleSettings() {
    // Map each settings panel selector to whether it should be visible,
    // then apply visibility in one pass.
    const panelVisibility = {
        '#vectors_files_settings': Boolean(settings.enabled_files),
        '#vectors_chats_settings': Boolean(settings.enabled_chats),
        '#together_vectorsModel': settings.source === 'togetherai',
    };

    for (const [selector, visible] of Object.entries(panelVisibility)) {
        $(selector).toggle(visible);
    }
}
|
||||
|
||||
async function onPurgeClick() {
|
||||
@ -617,6 +633,7 @@ jQuery(async () => {
|
||||
}
|
||||
|
||||
Object.assign(settings, extension_settings.vectors);
|
||||
|
||||
// Migrate from TensorFlow to Transformers
|
||||
settings.source = settings.source !== 'local' ? settings.source : 'transformers';
|
||||
$('#extensions_settings2').append(renderExtensionTemplate(MODULE_NAME, 'settings'));
|
||||
@ -636,6 +653,13 @@ jQuery(async () => {
|
||||
settings.source = String($('#vectors_source').val());
|
||||
Object.assign(extension_settings.vectors, settings);
|
||||
saveSettingsDebounced();
|
||||
toggleSettings();
|
||||
});
|
||||
|
||||
$('#vectors_togetherai_model').val(settings.togetherai_model).on('change', () => {
|
||||
settings.togetherai_model = String($('#vectors_togetherai_model').val());
|
||||
Object.assign(extension_settings.vectors, settings);
|
||||
saveSettingsDebounced();
|
||||
});
|
||||
$('#vectors_template').val(settings.template).on('input', () => {
|
||||
settings.template = String($('#vectors_template').val());
|
||||
|
@ -15,6 +15,22 @@
|
||||
<option value="openai">OpenAI</option>
|
||||
<option value="palm">Google MakerSuite (PaLM)</option>
|
||||
<option value="mistral">MistralAI</option>
|
||||
<option value="togetherai">TogetherAI</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="flex-container flexFlowColumn" id="together_vectorsModel">
|
||||
<label for="vectors_togetherai_model">
|
||||
Vectorization Model
|
||||
</label>
|
||||
<select id="vectors_togetherai_model" class="text_pole">
|
||||
<option value="togethercomputer/m2-bert-80M-32k-retrieval">M2-BERT-Retrieval-32k</option>
|
||||
<option value="togethercomputer/m2-bert-80M-8k-retrieval">M2-BERT-Retrieval-8k</option>
|
||||
<option value="togethercomputer/m2-bert-80M-2k-retrieval">M2-BERT-Retrieval-2k</option>
|
||||
<option value="WhereIsAI/UAE-Large-V1">UAE-Large-V1</option>
|
||||
<option value="BAAI/bge-large-en-v1.5">BAAI-Bge-Large-1p5</option>
|
||||
<option value="BAAI/bge-base-en-v1.5">BAAI-Bge-Base-1p5</option>
|
||||
<option value="sentence-transformers/msmarco-bert-base-dot-v5">Sentence-BERT</option>
|
||||
<option value="bert-base-uncased">Bert Base Uncased</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
|
@ -5,7 +5,7 @@ const sanitize = require('sanitize-filename');
|
||||
const { jsonParser } = require('../express-common');
|
||||
|
||||
// Don't forget to add new sources to the SOURCES array
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togetherai'];
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from the given source.
|
||||
@ -16,9 +16,10 @@ const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm'];
|
||||
*/
|
||||
async function getVector(source, sourceSettings, text) {
|
||||
switch (source) {
|
||||
case 'togetherai':
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
return require('../openai-vectors').getOpenAIVector(text, source);
|
||||
return require('../openai-vectors').getOpenAIVector(text, source, sourceSettings.model);
|
||||
case 'transformers':
|
||||
return require('../embedding').getTransformersVector(text);
|
||||
case 'extras':
|
||||
@ -44,9 +45,10 @@ async function getBatchVector(source, sourceSettings, texts) {
|
||||
let results = [];
|
||||
for (let batch of batches) {
|
||||
switch (source) {
|
||||
case 'togetherai':
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source));
|
||||
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, sourceSettings.model));
|
||||
break;
|
||||
case 'transformers':
|
||||
results.push(...await require('../embedding').getTransformersBatchVector(batch));
|
||||
@ -165,19 +167,26 @@ async function queryCollection(collectionId, source, sourceSettings, searchText,
|
||||
* @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
|
||||
*/
|
||||
function getSourceSettings(source, request) {
    if (source === 'togetherai') {
        // TogetherAI embedding model is chosen client-side and forwarded
        // via a custom header (see addTogetherAiHeaders on the client).
        const model = String(request.headers['x-togetherai-model']);

        return {
            model: model,
        };
    }

    // Extras API settings to connect to the Extras embeddings provider
    let extrasUrl = '';
    let extrasKey = '';
    if (source === 'extras') {
        extrasUrl = String(request.headers['x-extras-url']);
        extrasKey = String(request.headers['x-extras-key']);
    }

    // Sources other than 'extras' get empty strings here; the settings
    // object shape is kept stable for all non-TogetherAI sources.
    return {
        extrasUrl: extrasUrl,
        extrasKey: extrasKey,
    };
}
|
||||
|
||||
const router = express.Router();
|
||||
|
@ -2,6 +2,11 @@ const fetch = require('node-fetch').default;
|
||||
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
|
||||
const SOURCES = {
|
||||
'togetherai': {
|
||||
secretKey: SECRET_KEYS.TOGETHERAI,
|
||||
url: 'api.together.xyz',
|
||||
model: 'togethercomputer/m2-bert-80M-32k-retrieval',
|
||||
},
|
||||
'mistral': {
|
||||
secretKey: SECRET_KEYS.MISTRALAI,
|
||||
url: 'api.mistral.ai',
|
||||
@ -18,9 +23,10 @@ const SOURCES = {
|
||||
* Gets the vector for the given text batch from an OpenAI compatible endpoint.
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {string} model - The model to use for the embedding
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getOpenAIBatchVector(texts, source) {
|
||||
async function getOpenAIBatchVector(texts, source, model = '') {
|
||||
const config = SOURCES[source];
|
||||
|
||||
if (!config) {
|
||||
@ -44,7 +50,7 @@ async function getOpenAIBatchVector(texts, source) {
|
||||
},
|
||||
body: JSON.stringify({
|
||||
input: texts,
|
||||
model: config.model,
|
||||
model: model || config.model,
|
||||
}),
|
||||
});
|
||||
|
||||
@ -72,10 +78,11 @@ async function getOpenAIBatchVector(texts, source) {
|
||||
* Gets the vector for the given text from an OpenAI compatible endpoint.
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @param model
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getOpenAIVector(text, source, model = '') {
    // Delegate to the batch endpoint with a single-element batch and
    // unwrap the lone result. Empty model falls back to the source's
    // configured default inside getOpenAIBatchVector.
    const vectors = await getOpenAIBatchVector([text], source, model);
    return vectors[0];
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user