Implement TogetherAI as vectorization provider

This commit is contained in:
Kristan Schlikow
2024-03-01 23:52:49 +01:00
parent 95c49029f7
commit adfb9c5097
4 changed files with 88 additions and 22 deletions

View File

@@ -12,6 +12,7 @@ const settings = {
// For both
source: 'transformers',
include_wi: false,
model: 'togethercomputer/m2-bert-80M-32k-retrieval',
// For chats
enabled_chats: false,
@@ -440,6 +441,17 @@ function addExtrasHeaders(headers) {
});
}
/**
* Add headers for the Extras API source.
* @param {object} headers Headers object
*/
function addTogetherAiHeaders(headers) {
Object.assign(headers, {
'X-Togetherai-Model': extension_settings.vectors.model,
});
}
/**
* Inserts vector items into a collection
* @param {string} collectionId - The collection to insert into
@@ -449,7 +461,8 @@ function addExtrasHeaders(headers) {
async function insertVectorItems(collectionId, items) {
if (settings.source === 'openai' && !secret_state[SECRET_KEYS.OPENAI] ||
settings.source === 'palm' && !secret_state[SECRET_KEYS.MAKERSUITE] ||
settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI]) {
settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI] ||
settings.source === 'togetherai' && !secret_state[SECRET_KEYS.TOGETHERAI]) {
throw new Error('Vectors: API key missing', { cause: 'api_key_missing' });
}
@@ -460,6 +473,8 @@ async function insertVectorItems(collectionId, items) {
const headers = getRequestHeaders();
if (settings.source === 'extras') {
addExtrasHeaders(headers);
} else if (settings.source === 'togetherai') {
addTogetherAiHeaders(headers);
}
const response = await fetch('/api/vector/insert', {
@@ -509,6 +524,8 @@ async function queryCollection(collectionId, searchText, topK) {
const headers = getRequestHeaders();
if (settings.source === 'extras') {
addExtrasHeaders(headers);
} else if (settings.source === 'togetherai') {
addTogetherAiHeaders(headers);
}
const response = await fetch('/api/vector/query', {
@@ -526,8 +543,7 @@ async function queryCollection(collectionId, searchText, topK) {
throw new Error(`Failed to query collection ${collectionId}`);
}
const results = await response.json();
return results;
return await response.json();
}
/**
@@ -617,6 +633,13 @@ jQuery(async () => {
}
Object.assign(settings, extension_settings.vectors);
if (settings.source === 'togetherai') {
$('#vectorsModel').show();
} else {
$('#vectorsModel').hide();
}
// Migrate from TensorFlow to Transformers
settings.source = settings.source !== 'local' ? settings.source : 'transformers';
$('#extensions_settings2').append(renderExtensionTemplate(MODULE_NAME, 'settings'));
@@ -634,6 +657,17 @@ jQuery(async () => {
});
$('#vectors_source').val(settings.source).on('change', () => {
settings.source = String($('#vectors_source').val());
if (settings.source === 'togetherai') {
$('#vectorsModel').show();
} else {
$('#vectorsModel').hide();
}
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_model').val(settings.model).on('change', () => {
settings.model = String($('#vectors_model').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});