From 0ad756c9233caa71226cece47f6a5d0bb6314a00 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 24 Jan 2024 16:51:57 +0200 Subject: [PATCH] Add check for "embeddings" module. --- public/scripts/extensions/vectors/index.js | 49 ++++++++++++++++------ src/extras-vectors.js | 8 ++-- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js index fa72261b4..807f42676 100644 --- a/public/scripts/extensions/vectors/index.js +++ b/public/scripts/extensions/vectors/index.js @@ -1,5 +1,5 @@ import { eventSource, event_types, extension_prompt_types, getCurrentChatId, getRequestHeaders, is_send_press, saveSettingsDebounced, setExtensionPrompt, substituteParams } from '../../../script.js'; -import { ModuleWorkerWrapper, extension_settings, getContext, renderExtensionTemplate } from '../../extensions.js'; +import { ModuleWorkerWrapper, extension_settings, getContext, modules, renderExtensionTemplate } from '../../extensions.js'; import { collapseNewlines } from '../../power-user.js'; import { SECRET_KEYS, secret_state } from '../../secrets.js'; import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique, splitRecursive } from '../../utils.js'; @@ -152,8 +152,25 @@ async function synchronizeChat(batchSize = 5) { return newVectorItems.length - batchSize; } catch (error) { + /** + * Gets the error message for a given cause + * @param {string} cause Error cause key + * @returns {string} Error message + */ + function getErrorMessage(cause) { + switch (cause) { + case 'api_key_missing': + return 'API key missing. Save it in the "API Connections" panel.'; + case 'extras_module_missing': + return 'Extras API must provide an "embeddings" module.'; + default: + return 'Check server console for more details'; + } + } + console.error('Vectors: Failed to synchronize chat', error); - const message = error.cause === 'api_key_missing' ? 'API key missing. Save it in the "API Connections" panel.' : 'Check server console for more details'; + + const message = getErrorMessage(error.cause); toastr.error(message, 'Vectorization failed'); return -1; } finally { @@ -411,6 +428,18 @@ async function getSavedHashes(collectionId) { return hashes; } +/** + * Add headers for the Extras API source. + * @param {object} headers Headers object + */ +function addExtrasHeaders(headers) { + console.log(`Vector source is extras, populating API URL: ${extension_settings.apiUrl}`); + Object.assign(headers, { + 'X-Extras-Url': extension_settings.apiUrl, + 'X-Extras-Key': extension_settings.apiKey, + }); +} + /** * Inserts vector items into a collection * @param {string} collectionId - The collection to insert into @@ -424,13 +453,13 @@ async function insertVectorItems(collectionId, items) { throw new Error('Vectors: API key missing', { cause: 'api_key_missing' }); } + if (settings.source === 'extras' && !modules.includes('embeddings')) { + throw new Error('Vectors: Embeddings module missing', { cause: 'extras_module_missing' }); + } + const headers = getRequestHeaders(); if (settings.source === 'extras') { - console.log(`Vector source is extras, populating API URL: ${extension_settings.apiUrl}`); - Object.assign(headers, { - 'X-Extras-Url': extension_settings.apiUrl, - 'X-Extras-Key': extension_settings.apiKey - }); + addExtrasHeaders(headers); } const response = await fetch('/api/vector/insert', { @@ -479,11 +508,7 @@ async function deleteVectorItems(collectionId, hashes) { async function queryCollection(collectionId, searchText, topK) { const headers = getRequestHeaders(); if (settings.source === 'extras') { - console.log(`Vector source is extras, populating API URL: ${extension_settings.apiUrl}`); - Object.assign(headers, { - 'X-Extras-Url': extension_settings.apiUrl, - 'X-Extras-Key': extension_settings.apiKey - }); + addExtrasHeaders(headers); } const response = await fetch('/api/vector/query', { diff --git a/src/extras-vectors.js b/src/extras-vectors.js index af9c1a896..56da20633 100644 --- a/src/extras-vectors.js +++ b/src/extras-vectors.js @@ -4,7 +4,7 @@ const fetch = require('node-fetch').default; * Gets the vector for the given text from SillyTavern-extras * @param {string[]} texts - The array of texts to get the vectors for * @param {string} apiUrl - The Extras API URL - * @param {string} - The Extras API key, or empty string if API key not enabled + * @param {string} apiKey - The Extras API key, or empty string if API key not enabled * @returns {Promise} - The array of vectors for the texts */ async function getExtrasBatchVector(texts, apiUrl, apiKey) { @@ -15,7 +15,7 @@ async function getExtrasBatchVector(texts, apiUrl, apiKey) { * Gets the vector for the given text from SillyTavern-extras * @param {string} text - The text to get the vector for * @param {string} apiUrl - The Extras API URL - * @param {string} - The Extras API key, or empty string if API key not enabled + * @param {string} apiKey - The Extras API key, or empty string if API key not enabled * @returns {Promise} - The vector for the text */ async function getExtrasVector(text, apiUrl, apiKey) { @@ -26,8 +26,8 @@ async function getExtrasVector(text, apiUrl, apiKey) { * Gets the vector for the given text from SillyTavern-extras * @param {string|string[]} text - The text or texts to get the vector(s) for * @param {string} apiUrl - The Extras API URL - * @param {string} - The Extras API key, or empty string if API key not enabled - * @returns {Promise|Promise} - The vector for a single text, or the array of vectors for multiple texts + * @param {string} apiKey - The Extras API key, or empty string if API key not enabled * + * @returns {Promise} - The vector for a single text if input is string, or the array of vectors for multiple texts if input is string[] */ async function getExtrasVectorImpl(text, apiUrl, apiKey) { let url;