diff --git a/public/scripts/extensions.js b/public/scripts/extensions.js
index 5fbcdbdb5..ff70819d5 100644
--- a/public/scripts/extensions.js
+++ b/public/scripts/extensions.js
@@ -1070,7 +1070,7 @@ export async function installExtension(url, global) {
     toastr.success(t`Extension '${response.display_name}' by ${response.author} (version ${response.version}) has been installed successfully!`, t`Extension installation successful`);
     console.debug(`Extension "${response.display_name}" has been installed successfully at ${response.extensionPath}`);
     await loadExtensionSettings({}, false, false);
-    await eventSource.emit(event_types.EXTENSION_SETTINGS_LOADED);
+    await eventSource.emit(event_types.EXTENSION_SETTINGS_LOADED, response);
 }
 
 /**
diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js
index 60eea49f0..685609511 100644
--- a/public/scripts/extensions/vectors/index.js
+++ b/public/scripts/extensions/vectors/index.js
@@ -19,6 +19,7 @@ import {
     modules,
     renderExtensionTemplateAsync,
     doExtrasFetch, getApiUrl,
+    openThirdPartyExtensionMenu,
 } from '../../extensions.js';
 import { collapseNewlines, registerDebugFunction } from '../../power-user.js';
 import { SECRET_KEYS, secret_state, writeSecret } from '../../secrets.js';
@@ -34,6 +35,7 @@ import { SlashCommandEnumValue, enumTypes } from '../../slash-commands/SlashComm
 import { slashCommandReturnHelper } from '../../slash-commands/SlashCommandReturnHelper.js';
 import { callGenericPopup, POPUP_RESULT, POPUP_TYPE } from '../../popup.js';
 import { generateWebLlmChatPrompt, isWebLlmSupported } from '../shared.js';
+import { WebLlmVectorProvider } from './webllm.js';
 
 /**
  * @typedef {object} HashedMessage
@@ -60,6 +62,7 @@ const settings = {
     ollama_model: 'mxbai-embed-large',
     ollama_keep: false,
     vllm_model: '',
+    webllm_model: '',
     summarize: false,
     summarize_sent: false,
     summary_source: 'main',
@@ -103,7 +106,7 @@ const settings = {
 };
 
 const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
-
+const webllmProvider = new WebLlmVectorProvider();
 const cachedSummaries = new Map();
 
 /**
@@ -373,6 +376,8 @@ async function synchronizeChat(batchSize = 5) {
                 return 'Vectorization Source Model is required, but not set.';
             case 'extras_module_missing':
                 return 'Extras API must provide an "embeddings" module.';
+            case 'webllm_not_supported':
+                return 'WebLLM extension is not installed or the model is not set.';
             default:
                 return 'Check server console for more details';
         }
@@ -747,14 +752,15 @@ async function getQueryText(chat, initiator) {
 /**
  * Gets common body parameters for vector requests.
- * @returns {object}
+ * @param {object} args Additional arguments
+ * @returns {object} Request body
  */
-function getVectorsRequestBody() {
-    const body = {};
+function getVectorsRequestBody(args = {}) {
+    const body = Object.assign({}, args);
     switch (settings.source) {
         case 'extras':
-            body.extrasUrl = extension_settings.apiUrl;
-            body.extrasKey = extension_settings.apiKey;
+            body.extrasUrl = extension_settings.apiUrl;
+            body.extrasKey = extension_settings.apiKey;
             break;
         case 'togetherai':
             body.model = extension_settings.vectors.togetherai_model;
             break;
@@ -777,12 +783,30 @@
             body.apiUrl = textgenerationwebui_settings.server_urls[textgen_types.VLLM];
             body.model = extension_settings.vectors.vllm_model;
             break;
+        case 'webllm':
+            body.model = extension_settings.vectors.webllm_model;
+            break;
         default:
             break;
     }
     return body;
 }
 
+/**
+ * Gets additional arguments for vector requests.
+ * @param {string[]} items Items to embed
+ * @returns {Promise<object>} Additional arguments
+ */
+async function getAdditionalArgs(items) {
+    const args = {};
+    switch (settings.source) {
+        case 'webllm':
+            args.embeddings = await createWebLlmEmbeddings(items);
+            break;
+    }
+    return args;
+}
+
 /**
  * Gets the saved hashes for a collection
  * @param {string} collectionId
@@ -816,11 +840,12 @@ async function getSavedHashes(collectionId) {
 async function insertVectorItems(collectionId, items) {
     throwIfSourceInvalid();
 
+    const args = await getAdditionalArgs(items.map(x => x.text));
     const response = await fetch('/api/vector/insert', {
         method: 'POST',
         headers: getRequestHeaders(),
         body: JSON.stringify({
-            ...getVectorsRequestBody(),
+            ...getVectorsRequestBody(args),
             collectionId: collectionId,
             items: items,
             source: settings.source,
@@ -858,6 +883,10 @@ function throwIfSourceInvalid() {
     if (settings.source === 'extras' && !modules.includes('embeddings')) {
         throw new Error('Vectors: Embeddings module missing', { cause: 'extras_module_missing' });
     }
+
+    if (settings.source === 'webllm' && (!isWebLlmSupported() || !settings.webllm_model)) {
+        throw new Error('Vectors: WebLLM is not supported', { cause: 'webllm_not_supported' });
+    }
 }
 
 /**
@@ -890,11 +919,12 @@ async function deleteVectorItems(collectionId, hashes) {
  * @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results
  */
 async function queryCollection(collectionId, searchText, topK) {
+    const args = await getAdditionalArgs([searchText]);
     const response = await fetch('/api/vector/query', {
         method: 'POST',
         headers: getRequestHeaders(),
         body: JSON.stringify({
-            ...getVectorsRequestBody(),
+            ...getVectorsRequestBody(args),
             collectionId: collectionId,
             searchText: searchText,
             topK: topK,
@@ -919,11 +949,12 @@
  * @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - Results mapped to collection IDs
  */
 async function queryMultipleCollections(collectionIds, searchText, topK, threshold) {
+    const args = await getAdditionalArgs([searchText]);
     const response = await fetch('/api/vector/query-multi', {
         method: 'POST',
         headers: getRequestHeaders(),
         body: JSON.stringify({
-            ...getVectorsRequestBody(),
+            ...getVectorsRequestBody(args),
             collectionIds: collectionIds,
             searchText: searchText,
             topK: topK,
@@ -1039,6 +1070,72 @@ function toggleSettings() {
     $('#llamacpp_vectorsModel').toggle(settings.source === 'llamacpp');
     $('#vllm_vectorsModel').toggle(settings.source === 'vllm');
     $('#nomicai_apiKey').toggle(settings.source === 'nomicai');
+    $('#webllm_vectorsModel').toggle(settings.source === 'webllm');
+    if (settings.source === 'webllm') {
+        loadWebLlmModels();
+    }
+}
+
+/**
+ * Executes a function with WebLLM error handling.
+ * @param {function(): Promise<T>} func Function to execute
+ * @returns {Promise<T>}
+ * @template T
+ */
+async function executeWithWebLlmErrorHandling(func) {
+    try {
+        return await func();
+    } catch (error) {
+        console.log('Vectors: Failed to load WebLLM models', error);
+        if (!(error instanceof Error)) {
+            return;
+        }
+        switch (error.cause) {
+            case 'webllm-not-available':
+                toastr.warning('WebLLM is not available. Please install the extension.', 'WebLLM not installed');
+                break;
+            case 'webllm-not-updated':
+                toastr.warning('The installed extension version does not support embeddings.', 'WebLLM update required');
+                break;
+        }
+    }
+}
+
+/**
+ * Loads and displays WebLLM models in the settings.
+ * @returns {Promise<void>}
+ */
+function loadWebLlmModels() {
+    return executeWithWebLlmErrorHandling(() => {
+        const models = webllmProvider.getModels();
+        $('#vectors_webllm_model').empty();
+        for (const model of models) {
+            $('#vectors_webllm_model').append($('
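
The hunks above call createWebLlmEmbeddings(items), whose definition falls outside this excerpt. A minimal sketch of what such a helper could look like, assuming WebLlmVectorProvider exposes an embedTexts(texts, model) method (a hypothetical signature; the provider's actual API lives in the new ./webllm.js module, which is not shown here):

// Sketch only, not part of the diff above.
// Assumption: webllmProvider.embedTexts() resolves to one embedding vector per input text.
async function createWebLlmEmbeddings(items) {
    const vectors = await webllmProvider.embedTexts(items, settings.webllm_model);
    // Key each vector by its source string so the server can pass the precomputed
    // embeddings straight to the vector storage backend instead of re-embedding.
    return Object.fromEntries(items.map((item, index) => [item, vectors[index]]));
}

This fits the pattern the diff establishes for WebLLM: embeddings are computed client-side in the browser and shipped to the server inside the request body (args.embeddings merged in via getVectorsRequestBody(args)), whereas every other source only sends a model name or API URL and lets the server compute the vectors.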