mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Merge pull request #1737 from Technologicat/vectordb-with-extras
Initial support for Extras vectorizer, for Vector Storage
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
import { eventSource, event_types, extension_prompt_types, getCurrentChatId, getRequestHeaders, is_send_press, saveSettingsDebounced, setExtensionPrompt, substituteParams } from '../../../script.js';
|
import { eventSource, event_types, extension_prompt_types, getCurrentChatId, getRequestHeaders, is_send_press, saveSettingsDebounced, setExtensionPrompt, substituteParams } from '../../../script.js';
|
||||||
import { ModuleWorkerWrapper, extension_settings, getContext, renderExtensionTemplate } from '../../extensions.js';
|
import { ModuleWorkerWrapper, extension_settings, getContext, modules, renderExtensionTemplate } from '../../extensions.js';
|
||||||
import { collapseNewlines } from '../../power-user.js';
|
import { collapseNewlines } from '../../power-user.js';
|
||||||
import { SECRET_KEYS, secret_state } from '../../secrets.js';
|
import { SECRET_KEYS, secret_state } from '../../secrets.js';
|
||||||
import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique, splitRecursive } from '../../utils.js';
|
import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique, splitRecursive } from '../../utils.js';
|
||||||
@@ -152,8 +152,25 @@ async function synchronizeChat(batchSize = 5) {
|
|||||||
|
|
||||||
return newVectorItems.length - batchSize;
|
return newVectorItems.length - batchSize;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
/**
 * Gets the user-facing error message for a given vectorization failure cause.
 * @param {string} cause Error cause key (the `cause` attached to the thrown Error)
 * @returns {string} Error message. Unknown or missing causes fall back to a generic hint.
 */
function getErrorMessage(cause) {
    switch (cause) {
        case 'api_key_missing':
            return 'API key missing. Save it in the "API Connections" panel.';
        case 'extras_module_missing':
            return 'Extras API must provide an "embeddings" module.';
        default:
            return 'Check server console for more details';
    }
}
|
||||||
|
|
||||||
console.error('Vectors: Failed to synchronize chat', error);
|
console.error('Vectors: Failed to synchronize chat', error);
|
||||||
const message = error.cause === 'api_key_missing' ? 'API key missing. Save it in the "API Connections" panel.' : 'Check server console for more details';
|
|
||||||
|
const message = getErrorMessage(error.cause);
|
||||||
toastr.error(message, 'Vectorization failed');
|
toastr.error(message, 'Vectorization failed');
|
||||||
return -1;
|
return -1;
|
||||||
} finally {
|
} finally {
|
||||||
@@ -411,6 +428,18 @@ async function getSavedHashes(collectionId) {
|
|||||||
return hashes;
|
return hashes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Add headers for the Extras API source.
 * Mutates the given object in place, setting `X-Extras-Url` and `X-Extras-Key`
 * from `extension_settings`.
 * @param {object} headers Headers object
 */
function addExtrasHeaders(headers) {
    console.log(`Vector source is extras, populating API URL: ${extension_settings.apiUrl}`);
    Object.assign(headers, {
        'X-Extras-Url': extension_settings.apiUrl,
        // An unset key must go out as an empty value: an `undefined` header value
        // would be stringified and look like a real (non-empty) key downstream.
        'X-Extras-Key': extension_settings.apiKey ?? '',
    });
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Inserts vector items into a collection
|
* Inserts vector items into a collection
|
||||||
* @param {string} collectionId - The collection to insert into
|
* @param {string} collectionId - The collection to insert into
|
||||||
@@ -424,9 +453,18 @@ async function insertVectorItems(collectionId, items) {
|
|||||||
throw new Error('Vectors: API key missing', { cause: 'api_key_missing' });
|
throw new Error('Vectors: API key missing', { cause: 'api_key_missing' });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (settings.source === 'extras' && !modules.includes('embeddings')) {
|
||||||
|
throw new Error('Vectors: Embeddings module missing', { cause: 'extras_module_missing' });
|
||||||
|
}
|
||||||
|
|
||||||
|
const headers = getRequestHeaders();
|
||||||
|
if (settings.source === 'extras') {
|
||||||
|
addExtrasHeaders(headers);
|
||||||
|
}
|
||||||
|
|
||||||
const response = await fetch('/api/vector/insert', {
|
const response = await fetch('/api/vector/insert', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: getRequestHeaders(),
|
headers: headers,
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
collectionId: collectionId,
|
collectionId: collectionId,
|
||||||
items: items,
|
items: items,
|
||||||
@@ -468,9 +506,14 @@ async function deleteVectorItems(collectionId, hashes) {
|
|||||||
* @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results
|
* @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results
|
||||||
*/
|
*/
|
||||||
async function queryCollection(collectionId, searchText, topK) {
|
async function queryCollection(collectionId, searchText, topK) {
|
||||||
|
const headers = getRequestHeaders();
|
||||||
|
if (settings.source === 'extras') {
|
||||||
|
addExtrasHeaders(headers);
|
||||||
|
}
|
||||||
|
|
||||||
const response = await fetch('/api/vector/query', {
|
const response = await fetch('/api/vector/query', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: getRequestHeaders(),
|
headers: headers,
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
collectionId: collectionId,
|
collectionId: collectionId,
|
||||||
searchText: searchText,
|
searchText: searchText,
|
||||||
|
@@ -11,6 +11,7 @@
|
|||||||
</label>
|
</label>
|
||||||
<select id="vectors_source" class="text_pole">
|
<select id="vectors_source" class="text_pole">
|
||||||
<option value="transformers">Local (Transformers)</option>
|
<option value="transformers">Local (Transformers)</option>
|
||||||
|
<option value="extras">Extras</option>
|
||||||
<option value="openai">OpenAI</option>
|
<option value="openai">OpenAI</option>
|
||||||
<option value="palm">Google MakerSuite (PaLM)</option>
|
<option value="palm">Google MakerSuite (PaLM)</option>
|
||||||
<option value="mistral">MistralAI</option>
|
<option value="mistral">MistralAI</option>
|
||||||
|
@@ -7,16 +7,19 @@ const { jsonParser } = require('../express-common');
|
|||||||
/**
|
/**
|
||||||
* Gets the vector for the given text from the given source.
|
* Gets the vector for the given text from the given source.
|
||||||
* @param {string} source - The source of the vector
|
* @param {string} source - The source of the vector
|
||||||
|
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||||
* @param {string} text - The text to get the vector for
|
* @param {string} text - The text to get the vector for
|
||||||
* @returns {Promise<number[]>} - The vector for the text
|
* @returns {Promise<number[]>} - The vector for the text
|
||||||
*/
|
*/
|
||||||
async function getVector(source, text) {
|
async function getVector(source, sourceSettings, text) {
|
||||||
switch (source) {
|
switch (source) {
|
||||||
case 'mistral':
|
case 'mistral':
|
||||||
case 'openai':
|
case 'openai':
|
||||||
return require('../openai-vectors').getOpenAIVector(text, source);
|
return require('../openai-vectors').getOpenAIVector(text, source);
|
||||||
case 'transformers':
|
case 'transformers':
|
||||||
return require('../embedding').getTransformersVector(text);
|
return require('../embedding').getTransformersVector(text);
|
||||||
|
case 'extras':
|
||||||
|
return require('../extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey);
|
||||||
case 'palm':
|
case 'palm':
|
||||||
return require('../makersuite-vectors').getMakerSuiteVector(text);
|
return require('../makersuite-vectors').getMakerSuiteVector(text);
|
||||||
}
|
}
|
||||||
@@ -27,16 +30,19 @@ async function getVector(source, text) {
|
|||||||
/**
|
/**
|
||||||
* Gets the vector for the given text batch from the given source.
|
* Gets the vector for the given text batch from the given source.
|
||||||
* @param {string} source - The source of the vector
|
* @param {string} source - The source of the vector
|
||||||
|
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||||
* @param {string[]} texts - The array of texts to get the vector for
|
* @param {string[]} texts - The array of texts to get the vector for
|
||||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||||
*/
|
*/
|
||||||
async function getBatchVector(source, texts) {
|
async function getBatchVector(source, sourceSettings, texts) {
|
||||||
switch (source) {
|
switch (source) {
|
||||||
case 'mistral':
|
case 'mistral':
|
||||||
case 'openai':
|
case 'openai':
|
||||||
return require('../openai-vectors').getOpenAIBatchVector(texts, source);
|
return require('../openai-vectors').getOpenAIBatchVector(texts, source);
|
||||||
case 'transformers':
|
case 'transformers':
|
||||||
return require('../embedding').getTransformersBatchVector(texts);
|
return require('../embedding').getTransformersBatchVector(texts);
|
||||||
|
case 'extras':
|
||||||
|
return require('../extras-vectors').getExtrasBatchVector(texts, sourceSettings.extrasUrl, sourceSettings.extrasKey);
|
||||||
case 'palm':
|
case 'palm':
|
||||||
return require('../makersuite-vectors').getMakerSuiteBatchVector(texts);
|
return require('../makersuite-vectors').getMakerSuiteBatchVector(texts);
|
||||||
}
|
}
|
||||||
@@ -65,14 +71,15 @@ async function getIndex(collectionId, source, create = true) {
|
|||||||
* Inserts items into the vector collection
|
* Inserts items into the vector collection
|
||||||
* @param {string} collectionId - The collection ID
|
* @param {string} collectionId - The collection ID
|
||||||
* @param {string} source - The source of the vector
|
* @param {string} source - The source of the vector
|
||||||
|
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||||
* @param {{ hash: number; text: string; index: number; }[]} items - The items to insert
|
* @param {{ hash: number; text: string; index: number; }[]} items - The items to insert
|
||||||
*/
|
*/
|
||||||
async function insertVectorItems(collectionId, source, items) {
|
async function insertVectorItems(collectionId, source, sourceSettings, items) {
|
||||||
const store = await getIndex(collectionId, source);
|
const store = await getIndex(collectionId, source);
|
||||||
|
|
||||||
await store.beginUpdate();
|
await store.beginUpdate();
|
||||||
|
|
||||||
const vectors = await getBatchVector(source, items.map(x => x.text));
|
const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text));
|
||||||
|
|
||||||
for (let i = 0; i < items.length; i++) {
|
for (let i = 0; i < items.length; i++) {
|
||||||
const item = items[i];
|
const item = items[i];
|
||||||
@@ -121,13 +128,14 @@ async function deleteVectorItems(collectionId, source, hashes) {
|
|||||||
* Gets the hashes of the items in the vector collection that match the search text
|
* Gets the hashes of the items in the vector collection that match the search text
|
||||||
* @param {string} collectionId - The collection ID
|
* @param {string} collectionId - The collection ID
|
||||||
* @param {string} source - The source of the vector
|
* @param {string} source - The source of the vector
|
||||||
|
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||||
* @param {string} searchText - The text to search for
|
* @param {string} searchText - The text to search for
|
||||||
* @param {number} topK - The number of results to return
|
* @param {number} topK - The number of results to return
|
||||||
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
|
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
|
||||||
*/
|
*/
|
||||||
async function queryCollection(collectionId, source, searchText, topK) {
|
async function queryCollection(collectionId, source, sourceSettings, searchText, topK) {
|
||||||
const store = await getIndex(collectionId, source);
|
const store = await getIndex(collectionId, source);
|
||||||
const vector = await getVector(source, searchText);
|
const vector = await getVector(source, sourceSettings, searchText);
|
||||||
|
|
||||||
const result = await store.queryItems(vector, topK);
|
const result = await store.queryItems(vector, topK);
|
||||||
const metadata = result.map(x => x.item.metadata);
|
const metadata = result.map(x => x.item.metadata);
|
||||||
@@ -135,6 +143,28 @@ async function queryCollection(collectionId, source, searchText, topK) {
|
|||||||
return { metadata, hashes };
|
return { metadata, hashes };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Extracts settings for the vectorization sources from the HTTP request headers.
 * Only the 'extras' source reads anything from the request; for every other
 * source both fields are empty strings and the request is not inspected.
 * @param {string} source - Which source to extract settings for.
 * @param {object} request - The HTTP request object.
 * @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
 */
function getSourceSettings(source, request) {
    // Reads a single header as a string. An absent header becomes '' instead of
    // the text 'undefined', which would otherwise pass for a non-empty value.
    const readHeader = (name) => {
        const value = request.headers[name];
        return value == null ? '' : String(value);
    };

    // Extras API settings to connect to the Extras embeddings provider
    const isExtras = source === 'extras';
    const sourceSettings = {
        extrasUrl: isExtras ? readHeader('x-extras-url') : '',
        extrasKey: isExtras ? readHeader('x-extras-key') : '',
    };
    return sourceSettings;
}
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
router.post('/query', jsonParser, async (req, res) => {
|
router.post('/query', jsonParser, async (req, res) => {
|
||||||
@@ -147,8 +177,9 @@ router.post('/query', jsonParser, async (req, res) => {
|
|||||||
const searchText = String(req.body.searchText);
|
const searchText = String(req.body.searchText);
|
||||||
const topK = Number(req.body.topK) || 10;
|
const topK = Number(req.body.topK) || 10;
|
||||||
const source = String(req.body.source) || 'transformers';
|
const source = String(req.body.source) || 'transformers';
|
||||||
|
const sourceSettings = getSourceSettings(source, req);
|
||||||
|
|
||||||
const results = await queryCollection(collectionId, source, searchText, topK);
|
const results = await queryCollection(collectionId, source, sourceSettings, searchText, topK);
|
||||||
return res.json(results);
|
return res.json(results);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error);
|
console.error(error);
|
||||||
@@ -165,8 +196,9 @@ router.post('/insert', jsonParser, async (req, res) => {
|
|||||||
const collectionId = String(req.body.collectionId);
|
const collectionId = String(req.body.collectionId);
|
||||||
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text, index: x.index }));
|
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text, index: x.index }));
|
||||||
const source = String(req.body.source) || 'transformers';
|
const source = String(req.body.source) || 'transformers';
|
||||||
|
const sourceSettings = getSourceSettings(source, req);
|
||||||
|
|
||||||
await insertVectorItems(collectionId, source, items);
|
await insertVectorItems(collectionId, source, sourceSettings, items);
|
||||||
return res.sendStatus(200);
|
return res.sendStatus(200);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error);
|
console.error(error);
|
||||||
|
78
src/extras-vectors.js
Normal file
78
src/extras-vectors.js
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
const fetch = require('node-fetch').default;
|
||||||
|
|
||||||
|
/**
 * Gets the vectors for the given batch of texts from SillyTavern-extras
 * @param {string[]} texts - The array of texts to get the vectors for
 * @param {string} apiUrl - The Extras API URL
 * @param {string} apiKey - The Extras API key, or empty string if API key not enabled
 * @returns {Promise<number[][]>} - The array of vectors for the texts
 */
async function getExtrasBatchVector(texts, apiUrl, apiKey) {
    // The shared implementation accepts a string[] and forwards it as-is.
    return getExtrasVectorImpl(texts, apiUrl, apiKey);
}
|
||||||
|
|
||||||
|
/**
 * Gets the vector for the given text from SillyTavern-extras
 * @param {string} text - The text to get the vector for
 * @param {string} apiUrl - The Extras API URL
 * @param {string} apiKey - The Extras API key, or empty string if API key not enabled
 * @returns {Promise<number[]>} - The vector for the text
 */
async function getExtrasVector(text, apiUrl, apiKey) {
    // The shared implementation accepts a single string and forwards it as-is.
    return getExtrasVectorImpl(text, apiUrl, apiKey);
}
|
||||||
|
|
||||||
|
/**
 * Gets the vector(s) for the given text(s) from SillyTavern-extras
 * @param {string|string[]} text - The text or texts to get the vector(s) for
 * @param {string} apiUrl - The Extras API URL
 * @param {string} apiKey - The Extras API key, or empty string if API key not enabled
 * @returns {Promise<Array>} - The vector for a single text if input is string, or the array of vectors for multiple texts if input is string[]
 * @throws {TypeError} If `apiUrl` cannot be parsed as a URL (rethrown from the URL constructor)
 * @throws {Error} If the Extras API responds with a non-OK status
 */
async function getExtrasVectorImpl(text, apiUrl, apiKey) {
    let url;
    try {
        url = new URL(apiUrl);
        // Replaces whatever path the configured URL carried with the endpoint path.
        url.pathname = '/api/embeddings/compute';
    } catch (error) {
        console.log('Failed to set up Extras API call:', error);
        console.log('Extras API URL given was:', apiUrl);
        throw error;
    }

    const headers = {
        'Content-Type': 'application/json',
    };

    // Include the Extras API key, if enabled
    if (apiKey && apiKey.length > 0) {
        headers['Authorization'] = `Bearer ${apiKey}`;
    }

    const response = await fetch(url, {
        method: 'POST',
        headers: headers,
        body: JSON.stringify({
            text: text, // The backend accepts {string|string[]} for one or multiple text items, respectively.
        }),
    });

    if (!response.ok) {
        // Named distinctly so it does not shadow the `text` parameter.
        const responseText = await response.text();
        console.log('Extras request failed', response.statusText, responseText);
        throw new Error('Extras request failed');
    }

    const data = await response.json();
    // `embedding`: number[] (one text item), or number[][] (multiple text items).
    return data.embedding;
}
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
getExtrasVector,
|
||||||
|
getExtrasBatchVector,
|
||||||
|
};
|
Reference in New Issue
Block a user