Vectors: Don't use headers for source-specific fields in requests

This commit is contained in:
Cohee
2025-02-16 23:59:00 +02:00
parent a771dd5478
commit 058c86f3c1
2 changed files with 67 additions and 73 deletions

View File

@ -745,6 +745,44 @@ async function getQueryText(chat, initiator) {
return collapseNewlines(queryText).trim();
}
/**
* Gets common body parameters for vector requests.
* @returns {object}
*/
function getVectorsRequestBody() {
const body = {};
switch (settings.source) {
case 'extras':
body.extrasUrl = extension_settings.apiUrl;
body.extrasKey = extension_settings.apiKey;
break;
case 'togetherai':
body.model = extension_settings.vectors.togetherai_model;
break;
case 'openai':
body.model = extension_settings.vectors.openai_model;
break;
case 'cohere':
body.model = extension_settings.vectors.cohere_model;
break;
case 'ollama':
body.model = extension_settings.vectors.ollama_model;
body.apiUrl = textgenerationwebui_settings.server_urls[textgen_types.OLLAMA];
body.keep = !!extension_settings.vectors.ollama_keep;
break;
case 'llamacpp':
body.apiUrl = textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP];
break;
case 'vllm':
body.apiUrl = textgenerationwebui_settings.server_urls[textgen_types.VLLM];
body.model = extension_settings.vectors.vllm_model;
break;
default:
break;
}
return body;
}
/**
* Gets the saved hashes for a collection
* @param {string} collectionId
@ -753,8 +791,9 @@ async function getQueryText(chat, initiator) {
async function getSavedHashes(collectionId) {
const response = await fetch('/api/vector/list', {
method: 'POST',
headers: getVectorHeaders(),
headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId,
source: settings.source,
}),
@ -768,54 +807,6 @@ async function getSavedHashes(collectionId) {
return hashes;
}
function getVectorHeaders() {
const headers = getRequestHeaders();
switch (settings.source) {
case 'extras':
Object.assign(headers, {
'X-Extras-Url': extension_settings.apiUrl,
'X-Extras-Key': extension_settings.apiKey,
});
break;
case 'togetherai':
Object.assign(headers, {
'X-Togetherai-Model': extension_settings.vectors.togetherai_model,
});
break;
case 'openai':
Object.assign(headers, {
'X-OpenAI-Model': extension_settings.vectors.openai_model,
});
break;
case 'cohere':
Object.assign(headers, {
'X-Cohere-Model': extension_settings.vectors.cohere_model,
});
break;
case 'ollama':
Object.assign(headers, {
'X-Ollama-Model': extension_settings.vectors.ollama_model,
'X-Ollama-URL': textgenerationwebui_settings.server_urls[textgen_types.OLLAMA],
'X-Ollama-Keep': !!extension_settings.vectors.ollama_keep,
});
break;
case 'llamacpp':
Object.assign(headers, {
'X-LlamaCpp-URL': textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP],
});
break;
case 'vllm':
Object.assign(headers, {
'X-Vllm-URL': textgenerationwebui_settings.server_urls[textgen_types.VLLM],
'X-Vllm-Model': extension_settings.vectors.vllm_model,
});
break;
default:
break;
}
return headers;
}
/**
* Inserts vector items into a collection
* @param {string} collectionId - The collection to insert into
@ -825,12 +816,11 @@ function getVectorHeaders() {
async function insertVectorItems(collectionId, items) {
throwIfSourceInvalid();
const headers = getVectorHeaders();
const response = await fetch('/api/vector/insert', {
method: 'POST',
headers: headers,
headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId,
items: items,
source: settings.source,
@ -879,8 +869,9 @@ function throwIfSourceInvalid() {
async function deleteVectorItems(collectionId, hashes) {
const response = await fetch('/api/vector/delete', {
method: 'POST',
headers: getVectorHeaders(),
headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId,
hashes: hashes,
source: settings.source,
@ -899,12 +890,11 @@ async function deleteVectorItems(collectionId, hashes) {
* @returns {Promise<{ hashes: number[], metadata: object[]}>} - Hashes of the results
*/
async function queryCollection(collectionId, searchText, topK) {
const headers = getVectorHeaders();
const response = await fetch('/api/vector/query', {
method: 'POST',
headers: headers,
headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId,
searchText: searchText,
topK: topK,
@ -929,12 +919,11 @@ async function queryCollection(collectionId, searchText, topK) {
* @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - Results mapped to collection IDs
*/
async function queryMultipleCollections(collectionIds, searchText, topK, threshold) {
const headers = getVectorHeaders();
const response = await fetch('/api/vector/query-multi', {
method: 'POST',
headers: headers,
headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
collectionIds: collectionIds,
searchText: searchText,
topK: topK,
@ -965,8 +954,9 @@ async function purgeFileVectorIndex(fileUrl) {
const response = await fetch('/api/vector/purge', {
method: 'POST',
headers: getVectorHeaders(),
headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId,
}),
});
@ -994,8 +984,9 @@ async function purgeVectorIndex(collectionId) {
const response = await fetch('/api/vector/purge', {
method: 'POST',
headers: getVectorHeaders(),
headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
collectionId: collectionId,
}),
});
@ -1019,7 +1010,10 @@ async function purgeAllVectorIndexes() {
try {
const response = await fetch('/api/vector/purge-all', {
method: 'POST',
headers: getVectorHeaders(),
headers: getRequestHeaders(),
body: JSON.stringify({
...getVectorsRequestBody(),
}),
});
if (!response.ok) {

View File

@ -132,35 +132,35 @@ function getSourceSettings(source, request) {
switch (source) {
case 'togetherai':
return {
model: String(request.headers['x-togetherai-model']),
model: String(request.body.model),
};
case 'openai':
return {
model: String(request.headers['x-openai-model']),
model: String(request.body.model),
};
case 'cohere':
return {
model: String(request.headers['x-cohere-model']),
model: String(request.body.model),
};
case 'llamacpp':
return {
apiUrl: String(request.headers['x-llamacpp-url']),
apiUrl: String(request.body.apiUrl),
};
case 'vllm':
return {
apiUrl: String(request.headers['x-vllm-url']),
model: String(request.headers['x-vllm-model']),
apiUrl: String(request.body.apiUrl),
model: String(request.body.model),
};
case 'ollama':
return {
apiUrl: String(request.headers['x-ollama-url']),
model: String(request.headers['x-ollama-model']),
keep: Boolean(request.headers['x-ollama-keep']),
apiUrl: String(request.body.apiUrl),
model: String(request.body.model),
keep: Boolean(request.body.keep),
};
case 'extras':
return {
extrasUrl: String(request.headers['x-extras-url']),
extrasKey: String(request.headers['x-extras-key']),
extrasUrl: String(request.body.extrasUrl),
extrasKey: String(request.body.extrasKey),
};
case 'transformers':
return {