diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js index 685609511..c6b120dd9 100644 --- a/public/scripts/extensions/vectors/index.js +++ b/public/scripts/extensions/vectors/index.js @@ -565,6 +565,8 @@ async function retrieveFileChunks(queryText, collectionId) { * @returns {Promise} True if successful, false if not */ async function vectorizeFile(fileText, fileName, collectionId, chunkSize, overlapPercent) { + let toast = jQuery(); + try { if (settings.translate_files && typeof globalThis.translate === 'function') { console.log(`Vectors: Translating file ${fileName} to English...`); @@ -574,7 +576,7 @@ async function vectorizeFile(fileText, fileName, collectionId, chunkSize, overla const batchSize = getBatchSize(); const toastBody = $('').text('This may take a while. Please wait...'); - const toast = toastr.info(toastBody, `Ingesting file ${escapeHtml(fileName)}`, { closeButton: false, escapeHtml: false, timeOut: 0, extendedTimeOut: 0 }); + toast = toastr.info(toastBody, `Ingesting file ${escapeHtml(fileName)}`, { closeButton: false, escapeHtml: false, timeOut: 0, extendedTimeOut: 0 }); const overlapSize = Math.round(chunkSize * overlapPercent / 100); const delimiters = getChunkDelimiters(); // Overlap should not be included in chunk size. It will be later compensated by overlapChunks @@ -596,6 +598,7 @@ async function vectorizeFile(fileText, fileName, collectionId, chunkSize, overla console.log(`Vectors: Inserted ${chunks.length} vector items for file ${fileName} into ${collectionId}`); return true; } catch (error) { + toastr.clear(toast); toastr.error(String(error), 'Failed to vectorize file', { preventDuplicates: true }); console.error('Vectors: Failed to vectorize file', error); return false; @@ -803,6 +806,12 @@ async function getAdditionalArgs(items) { case 'webllm': args.embeddings = await createWebLlmEmbeddings(items); break; + case 'koboldcpp': { + const { embeddings, model } = await createKoboldCppEmbeddings(items); + args.embeddings = embeddings; + args.model = model; + break; + } } return args; } @@ -872,6 +881,7 @@ function throwIfSourceInvalid() { if (settings.source === 'ollama' && !textgenerationwebui_settings.server_urls[textgen_types.OLLAMA] || settings.source === 'vllm' && !textgenerationwebui_settings.server_urls[textgen_types.VLLM] || + settings.source === 'koboldcpp' && !textgenerationwebui_settings.server_urls[textgen_types.KOBOLDCPP] || settings.source === 'llamacpp' && !textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP]) { throw new Error('Vectors: API URL missing', { cause: 'api_url_missing' }); } @@ -1071,6 +1081,7 @@ function toggleSettings() { $('#vllm_vectorsModel').toggle(settings.source === 'vllm'); $('#nomicai_apiKey').toggle(settings.source === 'nomicai'); $('#webllm_vectorsModel').toggle(settings.source === 'webllm'); + $('#koboldcpp_vectorsModel').toggle(settings.source === 'koboldcpp'); if (settings.source === 'webllm') { loadWebLlmModels(); } @@ -1138,6 +1149,45 @@ async function createWebLlmEmbeddings(items) { }); } +/** + * Creates KoboldCpp embeddings for a list of items. + * @param {string[]} items Items to embed + * @returns {Promise<{embeddings: Record, model: string}>} Calculated embeddings + */ +async function createKoboldCppEmbeddings(items) { + const response = await fetch('/api/backends/kobold/embed', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ + items: items, + server: textgenerationwebui_settings.server_urls[textgen_types.KOBOLDCPP], + }), + }); + + if (!response.ok) { + throw new Error('Failed to get KoboldCpp embeddings'); + } + + const data = await response.json(); + if (!Array.isArray(data.embeddings) || !data.model || data.embeddings.length !== items.length) { + throw new Error('Invalid response from KoboldCpp embeddings'); + } + + const embeddings = /** @type {Record} */ ({}); + for (let i = 0; i < data.embeddings.length; i++) { + if (!Array.isArray(data.embeddings[i]) || data.embeddings[i].length === 0) { + throw new Error('KoboldCpp returned an empty embedding. Reduce the chunk size and/or size threshold and try again.'); + } + + embeddings[items[i]] = data.embeddings[i]; + } + + return { + embeddings: embeddings, + model: data.model, + }; +} + async function onPurgeClick() { const chatId = getCurrentChatId(); if (!chatId) { diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html index 381134f9f..ad5a59c0f 100644 --- a/public/scripts/extensions/vectors/settings.html +++ b/public/scripts/extensions/vectors/settings.html @@ -13,6 +13,7 @@ + @@ -55,6 +56,14 @@ Hint: Set the URL in the API connection settings. +
+ + Set the KoboldCpp URL in the Text Completion API connection settings. + + + Must use version 1.87 or higher and have an embedding model loaded. + +
The server MUST be started with the --embedding flag to use this feature! diff --git a/src/endpoints/backends/kobold.js b/src/endpoints/backends/kobold.js index 70bd5fac3..c0c158d62 100644 --- a/src/endpoints/backends/kobold.js +++ b/src/endpoints/backends/kobold.js @@ -237,3 +237,45 @@ router.post('/transcribe-audio', async function (request, response) { response.status(500).send('Internal server error'); } }); + +router.post('/embed', async function (request, response) { + try { + const { server, items } = request.body; + + if (!server) { + console.warn('KoboldCpp URL is not set'); + return response.sendStatus(400); + } + + const headers = {}; + setAdditionalHeadersByType(headers, TEXTGEN_TYPES.KOBOLDCPP, server, request.user.directories); + + const embeddingsUrl = new URL(server); + embeddingsUrl.pathname = '/api/extra/embeddings'; + + const embeddingsResult = await fetch(embeddingsUrl, { + method: 'POST', + headers: { + ...headers, + }, + body: JSON.stringify({ + input: items, + }), + }); + + /** @type {any} */ + const data = await embeddingsResult.json(); + + if (!Array.isArray(data?.data)) { + console.warn('KoboldCpp API response was not an array'); + return response.sendStatus(500); + } + + const model = data.model || 'unknown'; + const embeddings = data.data.map(x => Array.isArray(x) ? x[0] : x).sort((a, b) => a.index - b.index).map(x => x.embedding); + return response.json({ model, embeddings }); + } catch (error) { + console.error('KoboldCpp embedding failed', error); + response.status(500).send('Internal server error'); + } +}); diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js index 9630ce862..f26eacaef 100644 --- a/src/endpoints/vectors.js +++ b/src/endpoints/vectors.js @@ -31,6 +31,7 @@ const SOURCES = [ 'llamacpp', 'vllm', 'webllm', + 'koboldcpp', ]; /** @@ -66,6 +67,8 @@ async function getVector(source, sourceSettings, text, isQuery, directories) { return getOllamaVector(text, sourceSettings.apiUrl, sourceSettings.model, sourceSettings.keep, directories); case 'webllm': return sourceSettings.embeddings[text]; + case 'koboldcpp': + return sourceSettings.embeddings[text]; } throw new Error(`Unknown vector source ${source}`); @@ -119,6 +122,9 @@ async function getBatchVector(source, sourceSettings, texts, isQuery, directorie case 'webllm': results.push(...texts.map(x => sourceSettings.embeddings[x])); break; + case 'koboldcpp': + results.push(...texts.map(x => sourceSettings.embeddings[x])); + break; default: throw new Error(`Unknown vector source ${source}`); } @@ -189,6 +195,11 @@ function getSourceSettings(source, request) { model: String(request.body.model), embeddings: request.body.embeddings ?? {}, }; + case 'koboldcpp': + return { + model: String(request.body.model), + embeddings: request.body.embeddings ?? {}, + }; default: return {}; }