Vectors: Don't use headers for source-specific fields in requests

This commit is contained in:
Cohee
2025-02-16 23:59:00 +02:00
parent a771dd5478
commit 058c86f3c1
2 changed files with 67 additions and 73 deletions

View File

@ -745,6 +745,44 @@ async function getQueryText(chat, initiator) {
return collapseNewlines(queryText).trim(); return collapseNewlines(queryText).trim();
} }
/**
* Gets common body parameters for vector requests.
* @returns {object}
*/
function getVectorsRequestBody() {
const body = {};
switch (settings.source) {
case 'extras':
body.extrasUrl = extension_settings.apiUrl;
body.extrasKey = extension_settings.apiKey;
break;
case 'togetherai':
body.model = extension_settings.vectors.togetherai_model;
break;
case 'openai':
body.model = extension_settings.vectors.openai_model;
break;
case 'cohere':
body.model = extension_settings.vectors.cohere_model;
break;
case 'ollama':
body.model = extension_settings.vectors.ollama_model;
body.apiUrl = textgenerationwebui_settings.server_urls[textgen_types.OLLAMA];
body.keep = !!extension_settings.vectors.ollama_keep;
break;
case 'llamacpp':
body.apiUrl = textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP];
break;
case 'vllm':
body.apiUrl = textgenerationwebui_settings.server_urls[textgen_types.VLLM];
body.model = extension_settings.vectors.vllm_model;
break;
default:
break;
}
return body;
}
/** /**
* Gets the saved hashes for a collection * Gets the saved hashes for a collection
* @param {string} collectionId * @param {string} collectionId
@ -753,8 +791,9 @@ async function getQueryText(chat, initiator) {
async function getSavedHashes(collectionId) { async function getSavedHashes(collectionId) {
const response = await fetch('/api/vector/list', { const response = await fetch('/api/vector/list', {
method: 'POST', method: 'POST',
headers: getVectorHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({ body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId, collectionId: collectionId,
source: settings.source, source: settings.source,
}), }),
@ -768,54 +807,6 @@ async function getSavedHashes(collectionId) {
return hashes; return hashes;
} }
function getVectorHeaders() {
const headers = getRequestHeaders();
switch (settings.source) {
case 'extras':
Object.assign(headers, {
'X-Extras-Url': extension_settings.apiUrl,
'X-Extras-Key': extension_settings.apiKey,
});
break;
case 'togetherai':
Object.assign(headers, {
'X-Togetherai-Model': extension_settings.vectors.togetherai_model,
});
break;
case 'openai':
Object.assign(headers, {
'X-OpenAI-Model': extension_settings.vectors.openai_model,
});
break;
case 'cohere':
Object.assign(headers, {
'X-Cohere-Model': extension_settings.vectors.cohere_model,
});
break;
case 'ollama':
Object.assign(headers, {
'X-Ollama-Model': extension_settings.vectors.ollama_model,
'X-Ollama-URL': textgenerationwebui_settings.server_urls[textgen_types.OLLAMA],
'X-Ollama-Keep': !!extension_settings.vectors.ollama_keep,
});
break;
case 'llamacpp':
Object.assign(headers, {
'X-LlamaCpp-URL': textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP],
});
break;
case 'vllm':
Object.assign(headers, {
'X-Vllm-URL': textgenerationwebui_settings.server_urls[textgen_types.VLLM],
'X-Vllm-Model': extension_settings.vectors.vllm_model,
});
break;
default:
break;
}
return headers;
}
/** /**
* Inserts vector items into a collection * Inserts vector items into a collection
* @param {string} collectionId - The collection to insert into * @param {string} collectionId - The collection to insert into
@ -825,12 +816,11 @@ function getVectorHeaders() {
async function insertVectorItems(collectionId, items) { async function insertVectorItems(collectionId, items) {
throwIfSourceInvalid(); throwIfSourceInvalid();
const headers = getVectorHeaders();
const response = await fetch('/api/vector/insert', { const response = await fetch('/api/vector/insert', {
method: 'POST', method: 'POST',
headers: headers, headers: getRequestHeaders(),
body: JSON.stringify({ body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId, collectionId: collectionId,
items: items, items: items,
source: settings.source, source: settings.source,
@ -879,8 +869,9 @@ function throwIfSourceInvalid() {
async function deleteVectorItems(collectionId, hashes) { async function deleteVectorItems(collectionId, hashes) {
const response = await fetch('/api/vector/delete', { const response = await fetch('/api/vector/delete', {
method: 'POST', method: 'POST',
headers: getVectorHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({ body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId, collectionId: collectionId,
hashes: hashes, hashes: hashes,
source: settings.source, source: settings.source,
@ -899,12 +890,11 @@ async function deleteVectorItems(collectionId, hashes) {
* @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results * @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results
*/ */
async function queryCollection(collectionId, searchText, topK) { async function queryCollection(collectionId, searchText, topK) {
const headers = getVectorHeaders();
const response = await fetch('/api/vector/query', { const response = await fetch('/api/vector/query', {
method: 'POST', method: 'POST',
headers: headers, headers: getRequestHeaders(),
body: JSON.stringify({ body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId, collectionId: collectionId,
searchText: searchText, searchText: searchText,
topK: topK, topK: topK,
@ -929,12 +919,11 @@ async function queryCollection(collectionId, searchText, topK) {
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - Results mapped to collection IDs * @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - Results mapped to collection IDs
*/ */
async function queryMultipleCollections(collectionIds, searchText, topK, threshold) { async function queryMultipleCollections(collectionIds, searchText, topK, threshold) {
const headers = getVectorHeaders();
const response = await fetch('/api/vector/query-multi', { const response = await fetch('/api/vector/query-multi', {
method: 'POST', method: 'POST',
headers: headers, headers: getRequestHeaders(),
body: JSON.stringify({ body: JSON.stringify({
...getVectorsRequestBody(),
collectionIds: collectionIds, collectionIds: collectionIds,
searchText: searchText, searchText: searchText,
topK: topK, topK: topK,
@ -965,8 +954,9 @@ async function purgeFileVectorIndex(fileUrl) {
const response = await fetch('/api/vector/purge', { const response = await fetch('/api/vector/purge', {
method: 'POST', method: 'POST',
headers: getVectorHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({ body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId, collectionId: collectionId,
}), }),
}); });
@ -994,8 +984,9 @@ async function purgeVectorIndex(collectionId) {
const response = await fetch('/api/vector/purge', { const response = await fetch('/api/vector/purge', {
method: 'POST', method: 'POST',
headers: getVectorHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({ body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId, collectionId: collectionId,
}), }),
}); });
@ -1019,7 +1010,10 @@ async function purgeAllVectorIndexes() {
try { try {
const response = await fetch('/api/vector/purge-all', { const response = await fetch('/api/vector/purge-all', {
method: 'POST', method: 'POST',
headers: getVectorHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
}),
}); });
if (!response.ok) { if (!response.ok) {

View File

@ -132,35 +132,35 @@ function getSourceSettings(source, request) {
switch (source) { switch (source) {
case 'togetherai': case 'togetherai':
return { return {
model: String(request.headers['x-togetherai-model']), model: String(request.body.model),
}; };
case 'openai': case 'openai':
return { return {
model: String(request.headers['x-openai-model']), model: String(request.body.model),
}; };
case 'cohere': case 'cohere':
return { return {
model: String(request.headers['x-cohere-model']), model: String(request.body.model),
}; };
case 'llamacpp': case 'llamacpp':
return { return {
apiUrl: String(request.headers['x-llamacpp-url']), apiUrl: String(request.body.apiUrl),
}; };
case 'vllm': case 'vllm':
return { return {
apiUrl: String(request.headers['x-vllm-url']), apiUrl: String(request.body.apiUrl),
model: String(request.headers['x-vllm-model']), model: String(request.body.model),
}; };
case 'ollama': case 'ollama':
return { return {
apiUrl: String(request.headers['x-ollama-url']), apiUrl: String(request.body.apiUrl),
model: String(request.headers['x-ollama-model']), model: String(request.body.model),
keep: Boolean(request.headers['x-ollama-keep']), keep: Boolean(request.body.keep),
}; };
case 'extras': case 'extras':
return { return {
extrasUrl: String(request.headers['x-extras-url']), extrasUrl: String(request.body.extrasUrl),
extrasKey: String(request.headers['x-extras-key']), extrasKey: String(request.body.extrasKey),
}; };
case 'transformers': case 'transformers':
return { return {