Vectors WebLLM (#3631)

* Add WebLLM support for vectorization

* Load models when WebLLM extension installed

* Consistency updated

* Move checkWebLlm to initEngine

* Refactor vector request handling to use getAdditionalArgs

* Add error handling for unsupported WebLLM extension

* Add prefix to error causes
This commit is contained in:
Cohee
2025-03-09 00:51:44 +02:00
committed by GitHub
parent 0ea64050ff
commit 1cb9287684
5 changed files with 227 additions and 10 deletions

View File

@@ -19,6 +19,7 @@ import {
modules,
renderExtensionTemplateAsync,
doExtrasFetch, getApiUrl,
openThirdPartyExtensionMenu,
} from '../../extensions.js';
import { collapseNewlines, registerDebugFunction } from '../../power-user.js';
import { SECRET_KEYS, secret_state, writeSecret } from '../../secrets.js';
@@ -34,6 +35,7 @@ import { SlashCommandEnumValue, enumTypes } from '../../slash-commands/SlashComm
import { slashCommandReturnHelper } from '../../slash-commands/SlashCommandReturnHelper.js';
import { callGenericPopup, POPUP_RESULT, POPUP_TYPE } from '../../popup.js';
import { generateWebLlmChatPrompt, isWebLlmSupported } from '../shared.js';
import { WebLlmVectorProvider } from './webllm.js';
/**
* @typedef {object} HashedMessage
@@ -60,6 +62,7 @@ const settings = {
ollama_model: 'mxbai-embed-large',
ollama_keep: false,
vllm_model: '',
webllm_model: '',
summarize: false,
summarize_sent: false,
summary_source: 'main',
@@ -103,7 +106,7 @@ const settings = {
};
const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
const webllmProvider = new WebLlmVectorProvider();
const cachedSummaries = new Map();
/**
@@ -373,6 +376,8 @@ async function synchronizeChat(batchSize = 5) {
return 'Vectorization Source Model is required, but not set.';
case 'extras_module_missing':
return 'Extras API must provide an "embeddings" module.';
case 'webllm_not_supported':
return 'WebLLM extension is not installed or the model is not set.';
default:
return 'Check server console for more details';
}
@@ -747,14 +752,15 @@ async function getQueryText(chat, initiator) {
/**
* Gets common body parameters for vector requests.
* @returns {object}
* @param {object} args Additional arguments
* @returns {object} Request body
*/
function getVectorsRequestBody() {
const body = {};
function getVectorsRequestBody(args = {}) {
const body = Object.assign({}, args);
switch (settings.source) {
case 'extras':
body.extrasUrl = extension_settings.apiUrl;
body.extrasKey = extension_settings.apiKey;
body.extrasUrl = extension_settings.apiUrl;
body.extrasKey = extension_settings.apiKey;
break;
case 'togetherai':
body.model = extension_settings.vectors.togetherai_model;
@@ -777,12 +783,30 @@ function getVectorsRequestBody() {
body.apiUrl = textgenerationwebui_settings.server_urls[textgen_types.VLLM];
body.model = extension_settings.vectors.vllm_model;
break;
case 'webllm':
body.model = extension_settings.vectors.webllm_model;
break;
default:
break;
}
return body;
}
/**
 * Collects the extra request fields that depend on the active vector source.
 * Only the 'webllm' source contributes anything: the value resolved by
 * createWebLlmEmbeddings is attached under the `embeddings` key.
 * @param {string[]} items Items to embed
 * @returns {Promise<object>} Additional arguments; an empty object for every other source
 */
async function getAdditionalArgs(items) {
    if (settings.source !== 'webllm') {
        return {};
    }

    const embeddings = await createWebLlmEmbeddings(items);
    return { embeddings };
}
/**
* Gets the saved hashes for a collection
* @param {string} collectionId
@@ -816,11 +840,12 @@ async function getSavedHashes(collectionId) {
async function insertVectorItems(collectionId, items) {
throwIfSourceInvalid();
const args = await getAdditionalArgs(items.map(x => x.text));
const response = await fetch('/api/vector/insert', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
...getVectorsRequestBody(args),
collectionId: collectionId,
items: items,
source: settings.source,
@@ -858,6 +883,10 @@ function throwIfSourceInvalid() {
if (settings.source === 'extras' && !modules.includes('embeddings')) {
throw new Error('Vectors: Embeddings module missing', { cause: 'extras_module_missing' });
}
if (settings.source === 'webllm' && (!isWebLlmSupported() || !settings.webllm_model)) {
throw new Error('Vectors: WebLLM is not supported', { cause: 'webllm_not_supported' });
}
}
/**
@@ -890,11 +919,12 @@ async function deleteVectorItems(collectionId, hashes) {
* @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results
*/
async function queryCollection(collectionId, searchText, topK) {
const args = await getAdditionalArgs([searchText]);
const response = await fetch('/api/vector/query', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
...getVectorsRequestBody(args),
collectionId: collectionId,
searchText: searchText,
topK: topK,
@@ -919,11 +949,12 @@ async function queryCollection(collectionId, searchText, topK) {
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - Results mapped to collection IDs
*/
async function queryMultipleCollections(collectionIds, searchText, topK, threshold) {
const args = await getAdditionalArgs([searchText]);
const response = await fetch('/api/vector/query-multi', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
...getVectorsRequestBody(args),
collectionIds: collectionIds,
searchText: searchText,
topK: topK,
@@ -1039,6 +1070,72 @@ function toggleSettings() {
$('#llamacpp_vectorsModel').toggle(settings.source === 'llamacpp');
$('#vllm_vectorsModel').toggle(settings.source === 'vllm');
$('#nomicai_apiKey').toggle(settings.source === 'nomicai');
$('#webllm_vectorsModel').toggle(settings.source === 'webllm');
if (settings.source === 'webllm') {
loadWebLlmModels();
}
}
/**
 * Runs a WebLLM operation and turns any failure into a console log entry plus,
 * for the recognised error causes, a warning toast.
 *
 * Errors are never rethrown: when `func` throws or rejects, the returned
 * promise resolves to `undefined`, so callers must be prepared for a missing
 * result.
 * @param {function(): Promise<T>} func Function to execute
 * @returns {Promise<T>} Result of `func`; `undefined` if it failed
 * @template T
 */
async function executeWithWebLlmErrorHandling(func) {
    try {
        return await func();
    } catch (error) {
        // This helper wraps both model loading and embedding creation,
        // so the log message must not name a specific operation.
        console.log('Vectors: WebLLM operation failed', error);

        // Only Error instances carry a `cause` worth inspecting.
        if (!(error instanceof Error)) {
            return undefined;
        }

        switch (error.cause) {
            case 'webllm-not-available':
                toastr.warning('WebLLM is not available. Please install the extension.', 'WebLLM not installed');
                break;
            case 'webllm-not-updated':
                toastr.warning('The installed extension version does not support embeddings.', 'WebLLM update required');
                break;
            default:
                // Unrecognised causes are logged above and otherwise ignored.
                break;
        }

        return undefined;
    }
}
/**
 * Rebuilds the WebLLM model dropdown from the models reported by the provider.
 * When the saved model is unset or not among the reported models, the first
 * reported model (if any) becomes the saved model. The dropdown is then set to
 * the saved model. Failures are routed through executeWithWebLlmErrorHandling.
 * @returns {Promise<void>}
 */
function loadWebLlmModels() {
    return executeWithWebLlmErrorHandling(async () => {
        const availableModels = webllmProvider.getModels();
        const dropdown = $('#vectors_webllm_model');

        dropdown.empty();
        for (const entry of availableModels) {
            const option = $('<option>', { value: entry.id, text: entry.toString() });
            dropdown.append(option);
        }

        const isSavedModelOffered = Boolean(settings.webllm_model)
            && availableModels.some(entry => entry.id === settings.webllm_model);
        if (!isSavedModelOffered && availableModels.length) {
            settings.webllm_model = availableModels[0].id;
        }

        dropdown.val(settings.webllm_model);
    });
}
/**
 * Creates WebLLM embeddings for a list of items and returns them keyed by the
 * item text. When the same text occurs more than once, the embedding of its
 * last occurrence is kept.
 *
 * The work runs inside executeWithWebLlmErrorHandling, which does not rethrow:
 * if embedding fails, the returned promise resolves to `undefined`.
 * @param {string[]} items Items to embed
 * @returns {Promise<Record<string, number[]>>} Calculated embeddings
 */
async function createWebLlmEmbeddings(items) {
    return executeWithWebLlmErrorHandling(async () => {
        const embeddings = await webllmProvider.embedTexts(items, settings.webllm_model);
        const pairs = items.map((text, index) => [text, embeddings[index]]);
        // Object.fromEntries defines own properties, so an item whose text is
        // literally "__proto__" is stored like any other key. Plain assignment
        // (`result[text] = …`) would invoke the prototype setter and drop it.
        return /** @type {Record<string, number[]>} */ (Object.fromEntries(pairs));
    });
}
async function onPurgeClick() {
@@ -1567,6 +1664,30 @@ jQuery(async () => {
$('#dialogue_popup_input').val(presetModel);
});
$('#vectors_webllm_install').on('click', (e) => {
e.preventDefault();
e.stopPropagation();
if (Object.hasOwn(SillyTavern, 'llm')) {
toastr.info('WebLLM is already installed');
return;
}
openThirdPartyExtensionMenu('https://github.com/SillyTavern/Extension-WebLLM');
});
$('#vectors_webllm_model').on('input', () => {
settings.webllm_model = String($('#vectors_webllm_model').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_webllm_load').on('click', async () => {
if (!settings.webllm_model) return;
await webllmProvider.loadModel(settings.webllm_model);
toastr.success('WebLLM model loaded');
});
$('#api_key_nomicai').toggleClass('success', !!secret_state[SECRET_KEYS.NOMICAI]);
toggleSettings();
@@ -1578,6 +1699,11 @@ jQuery(async () => {
eventSource.on(event_types.CHAT_DELETED, purgeVectorIndex);
eventSource.on(event_types.GROUP_CHAT_DELETED, purgeVectorIndex);
eventSource.on(event_types.FILE_ATTACHMENT_DELETED, purgeFileVectorIndex);
eventSource.on(event_types.EXTENSION_SETTINGS_LOADED, async (manifest) => {
if (settings.source === 'webllm' && manifest?.display_name === 'WebLLM') {
await loadWebLlmModels();
}
});
SlashCommandParser.addCommandObject(SlashCommand.fromProps({
name: 'db-ingest',