mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-02-12 10:00:36 +01:00
Initial support for Extras vectorizer, for Vector Storage
This commit is contained in:
parent
958cf6a373
commit
8b43535352
@ -424,9 +424,18 @@ async function insertVectorItems(collectionId, items) {
|
||||
throw new Error('Vectors: API key missing', { cause: 'api_key_missing' });
|
||||
}
|
||||
|
||||
const headers = getRequestHeaders();
|
||||
if (settings.source === 'extras') {
|
||||
console.log(`Vector source is extras, populating API URL: ${extension_settings.apiUrl}`);
|
||||
Object.assign(headers, {
|
||||
'X-Extras-Url': extension_settings.apiUrl,
|
||||
'X-Extras-Key': extension_settings.apiKey
|
||||
});
|
||||
}
|
||||
|
||||
const response = await fetch('/api/vector/insert', {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
headers: headers,
|
||||
body: JSON.stringify({
|
||||
collectionId: collectionId,
|
||||
items: items,
|
||||
@ -468,9 +477,18 @@ async function deleteVectorItems(collectionId, hashes) {
|
||||
* @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results
|
||||
*/
|
||||
async function queryCollection(collectionId, searchText, topK) {
|
||||
const headers = getRequestHeaders();
|
||||
if (settings.source === 'extras') {
|
||||
console.log(`Vector source is extras, populating API URL: ${extension_settings.apiUrl}`);
|
||||
Object.assign(headers, {
|
||||
'X-Extras-Url': extension_settings.apiUrl,
|
||||
'X-Extras-Key': extension_settings.apiKey
|
||||
});
|
||||
}
|
||||
|
||||
const response = await fetch('/api/vector/query', {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
headers: headers,
|
||||
body: JSON.stringify({
|
||||
collectionId: collectionId,
|
||||
searchText: searchText,
|
||||
|
@ -11,6 +11,7 @@
|
||||
</label>
|
||||
<select id="vectors_source" class="text_pole">
|
||||
<option value="transformers">Local (Transformers)</option>
|
||||
<option value="extras">Extras</option>
|
||||
<option value="openai">OpenAI</option>
|
||||
<option value="palm">Google MakerSuite (PaLM)</option>
|
||||
<option value="mistral">MistralAI</option>
|
||||
|
@ -7,16 +7,19 @@ const { jsonParser } = require('../express-common');
|
||||
/**
|
||||
* Gets the vector for the given text from the given source.
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getVector(source, text) {
|
||||
async function getVector(source, sourceSettings, text) {
|
||||
switch (source) {
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
return require('../openai-vectors').getOpenAIVector(text, source);
|
||||
case 'transformers':
|
||||
return require('../embedding').getTransformersVector(text);
|
||||
case 'extras':
|
||||
return require('../extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey);
|
||||
case 'palm':
|
||||
return require('../makersuite-vectors').getMakerSuiteVector(text);
|
||||
}
|
||||
@ -45,9 +48,10 @@ async function getIndex(collectionId, source, create = true) {
|
||||
* Inserts items into the vector collection
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {{ hash: number; text: string; index: number; }[]} items - The items to insert
|
||||
*/
|
||||
async function insertVectorItems(collectionId, source, items) {
|
||||
async function insertVectorItems(collectionId, source, sourceSettings, items) {
|
||||
const store = await getIndex(collectionId, source);
|
||||
|
||||
await store.beginUpdate();
|
||||
@ -56,7 +60,7 @@ async function insertVectorItems(collectionId, source, items) {
|
||||
const text = item.text;
|
||||
const hash = item.hash;
|
||||
const index = item.index;
|
||||
const vector = await getVector(source, text);
|
||||
const vector = await getVector(source, sourceSettings, text);
|
||||
await store.upsertItem({ vector: vector, metadata: { hash, text, index } });
|
||||
}
|
||||
|
||||
@ -101,13 +105,14 @@ async function deleteVectorItems(collectionId, source, hashes) {
|
||||
* Gets the hashes of the items in the vector collection that match the search text
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string} searchText - The text to search for
|
||||
* @param {number} topK - The number of results to return
|
||||
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
|
||||
*/
|
||||
async function queryCollection(collectionId, source, searchText, topK) {
|
||||
async function queryCollection(collectionId, source, sourceSettings, searchText, topK) {
|
||||
const store = await getIndex(collectionId, source);
|
||||
const vector = await getVector(source, searchText);
|
||||
const vector = await getVector(source, sourceSettings, searchText);
|
||||
|
||||
const result = await store.queryItems(vector, topK);
|
||||
const metadata = result.map(x => x.item.metadata);
|
||||
@ -128,7 +133,19 @@ router.post('/query', jsonParser, async (req, res) => {
|
||||
const topK = Number(req.body.topK) || 10;
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
|
||||
const results = await queryCollection(collectionId, source, searchText, topK);
|
||||
// API settings for Extras embeddings provider
|
||||
let extrasUrl = '';
|
||||
let extrasKey = '';
|
||||
if (source === 'extras') {
|
||||
extrasUrl = String(req.headers['x-extras-url']);
|
||||
extrasKey = String(req.headers['x-extras-key']);
|
||||
}
|
||||
const sourceSettings = {
|
||||
extrasUrl: extrasUrl,
|
||||
extrasKey: extrasKey
|
||||
};
|
||||
|
||||
const results = await queryCollection(collectionId, source, sourceSettings, searchText, topK);
|
||||
return res.json(results);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
@ -146,7 +163,19 @@ router.post('/insert', jsonParser, async (req, res) => {
|
||||
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text, index: x.index }));
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
|
||||
await insertVectorItems(collectionId, source, items);
|
||||
// API settings for Extras embeddings provider
|
||||
let extrasUrl = '';
|
||||
let extrasKey = '';
|
||||
if (source === 'extras') {
|
||||
extrasUrl = String(req.headers['x-extras-url']);
|
||||
extrasKey = String(req.headers['x-extras-key']);
|
||||
}
|
||||
const sourceSettings = {
|
||||
extrasUrl: extrasUrl,
|
||||
extrasKey: extrasKey
|
||||
};
|
||||
|
||||
await insertVectorItems(collectionId, source, sourceSettings, items);
|
||||
return res.sendStatus(200);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
54
src/extras-vectors.js
Normal file
54
src/extras-vectors.js
Normal file
@ -0,0 +1,54 @@
|
||||
const fetch = require('node-fetch').default;
|
||||
|
||||
/**
 * Gets the vector for the given text from SillyTavern-extras.
 *
 * Sends a POST request with a JSON body `{ text }` to the
 * `/api/embeddings/compute` path on the host given by `apiUrl` and returns
 * the `embedding` field of the JSON response.
 *
 * @param {string|Array} text - The text or texts to get the vector for
 * @param {string} apiUrl - The Extras API URL
 * @param {string} apiKey - The Extras API key, or empty string if API key not enabled
 * @returns {Promise<number[]|number[][]>} - The vector for a single text, or the array of vectors for multiple texts
 * @throws {Error} If `apiUrl` cannot be parsed as a URL, or if the Extras API responds with a non-OK status
 */
async function getExtrasVector(text, apiUrl, apiKey) {
    let url;
    try {
        url = new URL(apiUrl);
        // Any path present in the configured URL is replaced by the embeddings endpoint.
        url.pathname = '/api/embeddings/compute';
    }
    catch (error) {
        console.log('Failed to set up Extras API call:', error);
        console.log('Extras API URL given was:', apiUrl);
        // Fail here with a clear message instead of falling through to
        // fetch(undefined), which would raise an unrelated, confusing error.
        throw new Error('Extras API URL is invalid', { cause: error });
    }

    const headers = {
        'Content-Type': 'application/json',
    };

    // Include the Extras API key, if enabled
    if (apiKey && apiKey.length > 0) {
        Object.assign(headers, {
            'Authorization': `Bearer ${apiKey}`,
        });
    }

    const response = await fetch(url, {
        method: 'POST',
        headers: headers,
        body: JSON.stringify({
            text: text, // The backend accepts {string|Array} for one or multiple text items, respectively.
        }),
    });

    if (!response.ok) {
        // Named differently from the `text` parameter to avoid shadowing it.
        const responseText = await response.text();
        console.log('Extras request failed', response.statusText, responseText);
        throw new Error('Extras request failed');
    }

    const data = await response.json();
    const vector = data.embedding; // `embedding`: Array (one text item), or Array of Array (multiple text items).

    return vector;
}
|
||||
|
||||
module.exports = {
|
||||
getExtrasVector,
|
||||
};
|
Loading…
x
Reference in New Issue
Block a user