Add support for KoboldCpp embeddings in Vector Storage (#3795)

* Add support for KoboldCpp embeddings in vector processing

* Add validation for KoboldCpp embeddings to handle empty data

* Improve toast handling
Cohee
2025-04-01 21:21:29 +03:00
committed by GitHub
parent 9c4404cae9
commit 80e821d12d
4 changed files with 113 additions and 1 deletion
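In outline, the change mirrors the existing webllm source: the extension computes embeddings client-side through a new /api/backends/kobold/embed endpoint and hands them to the vectors backend as a map keyed by the source text. A minimal sketch of that flow, reusing the names from the diff below (literal values and the surrounding async context are illustrative, not part of the commit):

// Sketch only: how the pieces added below fit together (run inside an async function).
const items = ['first chunk of text', 'second chunk of text'];

// createKoboldCppEmbeddings() (added in the first file) calls /api/backends/kobold/embed
// and returns a text-keyed map of vectors plus the model name reported by KoboldCpp.
const { embeddings, model } = await createKoboldCppEmbeddings(items);
// embeddings => { 'first chunk of text': [0.12, ...], 'second chunk of text': [0.34, ...] }

// getAdditionalArgs() attaches both to the request; the vectors backend (last file)
// then resolves each text against this map instead of embedding it again.
const args = { embeddings, model };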

View File

@@ -565,6 +565,8 @@ async function retrieveFileChunks(queryText, collectionId) {
* @returns {Promise<boolean>} True if successful, false if not
*/
async function vectorizeFile(fileText, fileName, collectionId, chunkSize, overlapPercent) {
let toast = jQuery();
try {
if (settings.translate_files && typeof globalThis.translate === 'function') {
console.log(`Vectors: Translating file ${fileName} to English...`);
@@ -574,7 +576,7 @@ async function vectorizeFile(fileText, fileName, collectionId, chunkSize, overla
const batchSize = getBatchSize();
const toastBody = $('<span>').text('This may take a while. Please wait...');
const toast = toastr.info(toastBody, `Ingesting file ${escapeHtml(fileName)}`, { closeButton: false, escapeHtml: false, timeOut: 0, extendedTimeOut: 0 });
toast = toastr.info(toastBody, `Ingesting file ${escapeHtml(fileName)}`, { closeButton: false, escapeHtml: false, timeOut: 0, extendedTimeOut: 0 });
const overlapSize = Math.round(chunkSize * overlapPercent / 100);
const delimiters = getChunkDelimiters();
// Overlap should not be included in chunk size. It will be later compensated by overlapChunks
@@ -596,6 +598,7 @@ async function vectorizeFile(fileText, fileName, collectionId, chunkSize, overla
console.log(`Vectors: Inserted ${chunks.length} vector items for file ${fileName} into ${collectionId}`);
return true;
} catch (error) {
toastr.clear(toast);
toastr.error(String(error), 'Failed to vectorize file', { preventDuplicates: true });
console.error('Vectors: Failed to vectorize file', error);
return false;
@@ -803,6 +806,12 @@ async function getAdditionalArgs(items) {
case 'webllm':
args.embeddings = await createWebLlmEmbeddings(items);
break;
case 'koboldcpp': {
const { embeddings, model } = await createKoboldCppEmbeddings(items);
args.embeddings = embeddings;
args.model = model;
break;
}
}
return args;
}
@@ -872,6 +881,7 @@ function throwIfSourceInvalid() {
if (settings.source === 'ollama' && !textgenerationwebui_settings.server_urls[textgen_types.OLLAMA] ||
settings.source === 'vllm' && !textgenerationwebui_settings.server_urls[textgen_types.VLLM] ||
settings.source === 'koboldcpp' && !textgenerationwebui_settings.server_urls[textgen_types.KOBOLDCPP] ||
settings.source === 'llamacpp' && !textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP]) {
throw new Error('Vectors: API URL missing', { cause: 'api_url_missing' });
}
@@ -1071,6 +1081,7 @@ function toggleSettings() {
$('#vllm_vectorsModel').toggle(settings.source === 'vllm');
$('#nomicai_apiKey').toggle(settings.source === 'nomicai');
$('#webllm_vectorsModel').toggle(settings.source === 'webllm');
$('#koboldcpp_vectorsModel').toggle(settings.source === 'koboldcpp');
if (settings.source === 'webllm') {
loadWebLlmModels();
}
@@ -1138,6 +1149,45 @@ async function createWebLlmEmbeddings(items) {
});
}
/**
* Creates KoboldCpp embeddings for a list of items.
* @param {string[]} items Items to embed
* @returns {Promise<{embeddings: Record<string, number[]>, model: string}>} Calculated embeddings
*/
async function createKoboldCppEmbeddings(items) {
const response = await fetch('/api/backends/kobold/embed', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
items: items,
server: textgenerationwebui_settings.server_urls[textgen_types.KOBOLDCPP],
}),
});
if (!response.ok) {
throw new Error('Failed to get KoboldCpp embeddings');
}
const data = await response.json();
if (!Array.isArray(data.embeddings) || !data.model || data.embeddings.length !== items.length) {
throw new Error('Invalid response from KoboldCpp embeddings');
}
const embeddings = /** @type {Record<string, number[]>} */ ({});
for (let i = 0; i < data.embeddings.length; i++) {
if (!Array.isArray(data.embeddings[i]) || data.embeddings[i].length === 0) {
throw new Error('KoboldCpp returned an empty embedding. Reduce the chunk size and/or size threshold and try again.');
}
embeddings[items[i]] = data.embeddings[i];
}
return {
embeddings: embeddings,
model: data.model,
};
}
async function onPurgeClick() {
const chatId = getCurrentChatId();
if (!chatId) {
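For reference, the response shape createKoboldCppEmbeddings() expects back from /api/backends/kobold/embed, inferred from the validation above and the /embed route added further down (names and numbers are placeholders):

// Assumed success payload from POST /api/backends/kobold/embed:
const exampleResponse = {
    model: 'text-embedding-model',
    embeddings: [
        [0.021, -0.044], // vector for items[0]
        [0.013, 0.009],  // vector for items[1]
    ],
};
// createKoboldCppEmbeddings() throws if model is missing, if embeddings.length differs
// from items.length, or if any inner vector is empty, so bad data never reaches insertion.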

View File

@@ -13,6 +13,7 @@
<option value="cohere">Cohere</option>
<option value="extras">Extras (deprecated)</option>
<option value="palm">Google AI Studio</option>
<option value="koboldcpp">KoboldCpp</option>
<option value="llamacpp">llama.cpp</option>
<option value="transformers" data-i18n="Local (Transformers)">Local (Transformers)</option>
<option value="mistral">MistralAI</option>
@@ -55,6 +56,14 @@
Hint: Set the URL in the API connection settings.
</i>
</div>
<div class="flex-container flexFlowColumn" id="koboldcpp_vectorsModel">
<span>
Set the KoboldCpp URL in the Text Completion API connection settings.
</span>
<span>
Must use version 1.87 or higher and have an embedding model loaded.
</span>
</div>
<div class="flex-container flexFlowColumn" id="llamacpp_vectorsModel">
<span data-i18n="The server MUST be started with the --embedding flag to use this feature!">
The server MUST be started with the <code>--embedding</code> flag to use this feature!

View File

@@ -237,3 +237,45 @@ router.post('/transcribe-audio', async function (request, response) {
response.status(500).send('Internal server error');
}
});
router.post('/embed', async function (request, response) {
try {
const { server, items } = request.body;
if (!server) {
console.warn('KoboldCpp URL is not set');
return response.sendStatus(400);
}
const headers = {};
setAdditionalHeadersByType(headers, TEXTGEN_TYPES.KOBOLDCPP, server, request.user.directories);
const embeddingsUrl = new URL(server);
embeddingsUrl.pathname = '/api/extra/embeddings';
const embeddingsResult = await fetch(embeddingsUrl, {
method: 'POST',
headers: {
...headers,
},
body: JSON.stringify({
input: items,
}),
});
/** @type {any} */
const data = await embeddingsResult.json();
if (!Array.isArray(data?.data)) {
console.warn('KoboldCpp API response was not an array');
return response.sendStatus(500);
}
const model = data.model || 'unknown';
const embeddings = data.data.map(x => Array.isArray(x) ? x[0] : x).sort((a, b) => a.index - b.index).map(x => x.embedding);
return response.json({ model, embeddings });
} catch (error) {
console.error('KoboldCpp embedding failed', error);
response.status(500).send('Internal server error');
}
});
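The /embed route normalizes whatever /api/extra/embeddings returns before replying. A small standalone sketch of that transform, assuming KoboldCpp answers with an OpenAI-style embeddings list (the shape is an assumption for illustration, not taken from the commit):

// Assumed upstream payload: an OpenAI-compatible embeddings response, possibly out of
// order and possibly with each entry wrapped in a one-element array.
const data = {
    model: 'text-embedding-model',
    data: [
        [{ index: 1, embedding: [0.3, 0.4] }],
        { index: 0, embedding: [0.1, 0.2] },
    ],
};

// Same chain as the route above: unwrap, restore input order, keep only the vectors.
const embeddings = data.data
    .map(x => Array.isArray(x) ? x[0] : x)
    .sort((a, b) => a.index - b.index)
    .map(x => x.embedding);

console.log(embeddings);              // [[0.1, 0.2], [0.3, 0.4]]
console.log(data.model || 'unknown'); // 'text-embedding-model'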

View File

@@ -31,6 +31,7 @@ const SOURCES = [
'llamacpp',
'vllm',
'webllm',
'koboldcpp',
];
/**
@@ -66,6 +67,8 @@ async function getVector(source, sourceSettings, text, isQuery, directories) {
return getOllamaVector(text, sourceSettings.apiUrl, sourceSettings.model, sourceSettings.keep, directories);
case 'webllm':
return sourceSettings.embeddings[text];
case 'koboldcpp':
return sourceSettings.embeddings[text];
}
throw new Error(`Unknown vector source ${source}`);
@@ -119,6 +122,9 @@ async function getBatchVector(source, sourceSettings, texts, isQuery, directorie
case 'webllm':
results.push(...texts.map(x => sourceSettings.embeddings[x]));
break;
case 'koboldcpp':
results.push(...texts.map(x => sourceSettings.embeddings[x]));
break;
default:
throw new Error(`Unknown vector source ${source}`);
}
@@ -189,6 +195,11 @@ function getSourceSettings(source, request) {
model: String(request.body.model),
embeddings: request.body.embeddings ?? {},
};
case 'koboldcpp':
return {
model: String(request.body.model),
embeddings: request.body.embeddings ?? {},
};
default:
return {};
}
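On this side, koboldcpp is a pure pass-through like webllm: the vectors endpoint never contacts KoboldCpp itself, it only resolves each text against the map the client sent in request.body.embeddings. A minimal illustration (values are placeholders):

// Illustrative only: what getSourceSettings() hands back for a koboldcpp request.
const sourceSettings = {
    model: 'text-embedding-model',
    embeddings: { 'some chunk of text': [0.1, 0.2, 0.3] },
};

// getVector()/getBatchVector() then just look the vectors up by their source text.
// Texts missing from the map come back as undefined; there is no server-side fallback.
const vector = sourceSettings.embeddings['some chunk of text']; // [0.1, 0.2, 0.3]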