From 058c86f3c1e7eba81029236f9d695196f1668e2a Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Sun, 16 Feb 2025 23:59:00 +0200
Subject: [PATCH] Vectors: Don't use headers for source-specific fields in requests

---
 public/scripts/extensions/vectors/index.js | 118 ++++++++++-----------
 src/endpoints/vectors.js                   |  22 ++--
 2 files changed, 67 insertions(+), 73 deletions(-)

diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js
index aea42d909..5506bec00 100644
--- a/public/scripts/extensions/vectors/index.js
+++ b/public/scripts/extensions/vectors/index.js
@@ -745,6 +745,44 @@ async function getQueryText(chat, initiator) {
     return collapseNewlines(queryText).trim();
 }
 
+/**
+ * Gets common body parameters for vector requests.
+ * @returns {object}
+ */
+function getVectorsRequestBody() {
+    const body = {};
+    switch (settings.source) {
+        case 'extras':
+            body.extrasUrl = extension_settings.apiUrl;
+            body.extrasKey = extension_settings.apiKey;
+            break;
+        case 'togetherai':
+            body.model = extension_settings.vectors.togetherai_model;
+            break;
+        case 'openai':
+            body.model = extension_settings.vectors.openai_model;
+            break;
+        case 'cohere':
+            body.model = extension_settings.vectors.cohere_model;
+            break;
+        case 'ollama':
+            body.model = extension_settings.vectors.ollama_model;
+            body.apiUrl = textgenerationwebui_settings.server_urls[textgen_types.OLLAMA];
+            body.keep = !!extension_settings.vectors.ollama_keep;
+            break;
+        case 'llamacpp':
+            body.apiUrl = textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP];
+            break;
+        case 'vllm':
+            body.apiUrl = textgenerationwebui_settings.server_urls[textgen_types.VLLM];
+            body.model = extension_settings.vectors.vllm_model;
+            break;
+        default:
+            break;
+    }
+    return body;
+}
+
 /**
  * Gets the saved hashes for a collection
  * @param {string} collectionId
@@ -753,8 +791,9 @@ async function getQueryText(chat, initiator) {
 async function getSavedHashes(collectionId) {
     const response = await fetch('/api/vector/list', {
         method: 'POST',
-        headers: getVectorHeaders(),
+        headers: getRequestHeaders(),
         body: JSON.stringify({
+            ...getVectorsRequestBody(),
             collectionId: collectionId,
             source: settings.source,
         }),
@@ -768,54 +807,6 @@ async function getSavedHashes(collectionId) {
     return hashes;
 }
 
-function getVectorHeaders() {
-    const headers = getRequestHeaders();
-    switch (settings.source) {
-        case 'extras':
-            Object.assign(headers, {
-                'X-Extras-Url': extension_settings.apiUrl,
-                'X-Extras-Key': extension_settings.apiKey,
-            });
-            break;
-        case 'togetherai':
-            Object.assign(headers, {
-                'X-Togetherai-Model': extension_settings.vectors.togetherai_model,
-            });
-            break;
-        case 'openai':
-            Object.assign(headers, {
-                'X-OpenAI-Model': extension_settings.vectors.openai_model,
-            });
-            break;
-        case 'cohere':
-            Object.assign(headers, {
-                'X-Cohere-Model': extension_settings.vectors.cohere_model,
-            });
-            break;
-        case 'ollama':
-            Object.assign(headers, {
-                'X-Ollama-Model': extension_settings.vectors.ollama_model,
-                'X-Ollama-URL': textgenerationwebui_settings.server_urls[textgen_types.OLLAMA],
-                'X-Ollama-Keep': !!extension_settings.vectors.ollama_keep,
-            });
-            break;
-        case 'llamacpp':
-            Object.assign(headers, {
-                'X-LlamaCpp-URL': textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP],
-            });
-            break;
-        case 'vllm':
-            Object.assign(headers, {
-                'X-Vllm-URL': textgenerationwebui_settings.server_urls[textgen_types.VLLM],
-                'X-Vllm-Model': extension_settings.vectors.vllm_model,
-            });
-            break;
-        default:
-            break;
-    }
-    return headers;
-}
-
 /**
  * Inserts vector items into a collection
  * @param {string} collectionId - The collection to insert into
@@ -825,12 +816,11 @@ function getVectorHeaders() {
 async function insertVectorItems(collectionId, items) {
     throwIfSourceInvalid();
 
-    const headers = getVectorHeaders();
-
     const response = await fetch('/api/vector/insert', {
         method: 'POST',
-        headers: headers,
+        headers: getRequestHeaders(),
         body: JSON.stringify({
+            ...getVectorsRequestBody(),
             collectionId: collectionId,
             items: items,
             source: settings.source,
@@ -879,8 +869,9 @@ function throwIfSourceInvalid() {
 async function deleteVectorItems(collectionId, hashes) {
     const response = await fetch('/api/vector/delete', {
         method: 'POST',
-        headers: getVectorHeaders(),
+        headers: getRequestHeaders(),
         body: JSON.stringify({
+            ...getVectorsRequestBody(),
             collectionId: collectionId,
             hashes: hashes,
             source: settings.source,
@@ -899,12 +890,11 @@ async function deleteVectorItems(collectionId, hashes) {
  * @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results
  */
 async function queryCollection(collectionId, searchText, topK) {
-    const headers = getVectorHeaders();
-
     const response = await fetch('/api/vector/query', {
         method: 'POST',
-        headers: headers,
+        headers: getRequestHeaders(),
         body: JSON.stringify({
+            ...getVectorsRequestBody(),
             collectionId: collectionId,
             searchText: searchText,
             topK: topK,
@@ -929,12 +919,11 @@ async function queryCollection(collectionId, searchText, topK) {
  * @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - Results mapped to collection IDs
  */
 async function queryMultipleCollections(collectionIds, searchText, topK, threshold) {
-    const headers = getVectorHeaders();
-
     const response = await fetch('/api/vector/query-multi', {
         method: 'POST',
-        headers: headers,
+        headers: getRequestHeaders(),
         body: JSON.stringify({
+            ...getVectorsRequestBody(),
             collectionIds: collectionIds,
             searchText: searchText,
             topK: topK,
@@ -965,8 +954,9 @@ async function purgeFileVectorIndex(fileUrl) {
 
         const response = await fetch('/api/vector/purge', {
             method: 'POST',
-            headers: getVectorHeaders(),
+            headers: getRequestHeaders(),
             body: JSON.stringify({
+                ...getVectorsRequestBody(),
                 collectionId: collectionId,
             }),
         });
@@ -994,8 +984,9 @@ async function purgeVectorIndex(collectionId) {
 
         const response = await fetch('/api/vector/purge', {
             method: 'POST',
-            headers: getVectorHeaders(),
+            headers: getRequestHeaders(),
             body: JSON.stringify({
+                ...getVectorsRequestBody(),
                 collectionId: collectionId,
             }),
         });
@@ -1019,7 +1010,10 @@ async function purgeAllVectorIndexes() {
     try {
         const response = await fetch('/api/vector/purge-all', {
            method: 'POST',
-            headers: getVectorHeaders(),
+            headers: getRequestHeaders(),
+            body: JSON.stringify({
+                ...getVectorsRequestBody(),
+            }),
         });
 
         if (!response.ok) {
diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js
index 3cab324d7..abab3b7d1 100644
--- a/src/endpoints/vectors.js
+++ b/src/endpoints/vectors.js
@@ -132,35 +132,35 @@ function getSourceSettings(source, request) {
     switch (source) {
         case 'togetherai':
             return {
-                model: String(request.headers['x-togetherai-model']),
+                model: String(request.body.model),
             };
         case 'openai':
             return {
-                model: String(request.headers['x-openai-model']),
+                model: String(request.body.model),
             };
         case 'cohere':
             return {
-                model: String(request.headers['x-cohere-model']),
+                model: String(request.body.model),
             };
         case 'llamacpp':
             return {
-                apiUrl: String(request.headers['x-llamacpp-url']),
+                apiUrl: String(request.body.apiUrl),
             };
         case 'vllm':
             return {
-                apiUrl: String(request.headers['x-vllm-url']),
-                model: String(request.headers['x-vllm-model']),
+                apiUrl: String(request.body.apiUrl),
+                model: String(request.body.model),
             };
         case 'ollama':
             return {
-                apiUrl: String(request.headers['x-ollama-url']),
-                model: String(request.headers['x-ollama-model']),
-                keep: Boolean(request.headers['x-ollama-keep']),
+                apiUrl: String(request.body.apiUrl),
+                model: String(request.body.model),
+                keep: Boolean(request.body.keep),
             };
         case 'extras':
             return {
-                extrasUrl: String(request.headers['x-extras-url']),
-                extrasKey: String(request.headers['x-extras-key']),
+                extrasUrl: String(request.body.extrasUrl),
+                extrasKey: String(request.body.extrasKey),
             };
         case 'transformers':
             return {
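
Note on the resulting contract: the two halves of the patch stay symmetrical. Whatever getVectorsRequestBody() spreads into the client's JSON body is exactly what getSourceSettings() reads back out of request.body on the server, so no custom X-* headers are needed anymore. A minimal annotated sketch of the payload for one source ('ollama'); all field values below are invented examples, not part of the patch:

// Sketch only: the JSON body the patched queryCollection() sends when
// settings.source === 'ollama'. Values are illustrative.
const exampleBody = {
    // spread in by getVectorsRequestBody():
    model: 'nomic-embed-text',        // extension_settings.vectors.ollama_model
    apiUrl: 'http://127.0.0.1:11434', // textgenerationwebui_settings.server_urls[textgen_types.OLLAMA]
    keep: false,                      // !!extension_settings.vectors.ollama_keep
    // endpoint-specific fields, unchanged by this patch:
    collectionId: 'chat-abc123',
    searchText: 'example query',
    topK: 10,
    source: 'ollama',
};
// Server side, getSourceSettings('ollama', request) recovers the same values
// via request.body.apiUrl, request.body.model, and request.body.keep instead
// of the old X-Ollama-URL / X-Ollama-Model / X-Ollama-Keep headers.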