Add ollama and llamacpp as vector sources

This commit is contained in:
Cohee 2024-05-28 22:54:50 +03:00
parent c858fccc5f
commit 2b3dfc5ae2
6 changed files with 286 additions and 20 deletions

View File

@ -25,6 +25,7 @@ import { getDataBankAttachments, getFileAttachment } from '../../chats.js';
import { debounce, getStringHash as calculateHash, waitUntilCondition, onlyUnique, splitRecursive } from '../../utils.js';
import { debounce_timeout } from '../../constants.js';
import { getSortedEntries } from '../../world-info.js';
import { textgen_types, textgenerationwebui_settings } from '../../textgen-settings.js';
const MODULE_NAME = 'vectors';
@ -38,6 +39,8 @@ const settings = {
togetherai_model: 'togethercomputer/m2-bert-80M-32k-retrieval',
openai_model: 'text-embedding-ada-002',
cohere_model: 'embed-english-v3.0',
ollama_model: 'mxbai-embed-large',
ollama_keep: false,
summarize: false,
summarize_sent: false,
summary_source: 'main',
@ -272,6 +275,10 @@ async function synchronizeChat(batchSize = 5) {
switch (cause) {
case 'api_key_missing':
return 'API key missing. Save it in the "API Connections" panel.';
case 'api_url_missing':
return 'API URL missing. Save it in the "API Connections" panel.';
case 'api_model_missing':
return 'Vectorization Source Model is required, but not set.';
case 'extras_module_missing':
return 'Extras API must provide an "embeddings" module.';
default:
@ -637,6 +644,12 @@ function getVectorHeaders() {
case 'cohere':
addCohereHeaders(headers);
break;
case 'ollama':
addOllamaHeaders(headers);
break;
case 'llamacpp':
addLlamaCppHeaders(headers);
break;
default:
break;
}
@ -685,6 +698,28 @@ function addCohereHeaders(headers) {
});
}
/**
* Add headers for the Ollama API source.
* @param {object} headers Header object
*/
function addOllamaHeaders(headers) {
Object.assign(headers, {
'X-Ollama-Model': extension_settings.vectors.ollama_model,
'X-Ollama-URL': textgenerationwebui_settings.server_urls[textgen_types.OLLAMA],
'X-Ollama-Keep': !!extension_settings.vectors.ollama_keep,
});
}
/**
* Add headers for the LlamaCpp API source.
* @param {object} headers Header object
*/
function addLlamaCppHeaders(headers) {
Object.assign(headers, {
'X-LlamaCpp-URL': textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP],
});
}
/**
* Inserts vector items into a collection
* @param {string} collectionId - The collection to insert into
@ -692,18 +727,7 @@ function addCohereHeaders(headers) {
* @returns {Promise<void>}
*/
async function insertVectorItems(collectionId, items) {
if (settings.source === 'openai' && !secret_state[SECRET_KEYS.OPENAI] ||
settings.source === 'palm' && !secret_state[SECRET_KEYS.MAKERSUITE] ||
settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI] ||
settings.source === 'togetherai' && !secret_state[SECRET_KEYS.TOGETHERAI] ||
settings.source === 'nomicai' && !secret_state[SECRET_KEYS.NOMICAI] ||
settings.source === 'cohere' && !secret_state[SECRET_KEYS.COHERE]) {
throw new Error('Vectors: API key missing', { cause: 'api_key_missing' });
}
if (settings.source === 'extras' && !modules.includes('embeddings')) {
throw new Error('Vectors: Embeddings module missing', { cause: 'extras_module_missing' });
}
throwIfSourceInvalid();
const headers = getVectorHeaders();
@ -722,6 +746,33 @@ async function insertVectorItems(collectionId, items) {
}
}
/**
* Throws an error if the source is invalid (missing API key or URL, or missing module)
*/
function throwIfSourceInvalid() {
if (settings.source === 'openai' && !secret_state[SECRET_KEYS.OPENAI] ||
settings.source === 'palm' && !secret_state[SECRET_KEYS.MAKERSUITE] ||
settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI] ||
settings.source === 'togetherai' && !secret_state[SECRET_KEYS.TOGETHERAI] ||
settings.source === 'nomicai' && !secret_state[SECRET_KEYS.NOMICAI] ||
settings.source === 'cohere' && !secret_state[SECRET_KEYS.COHERE]) {
throw new Error('Vectors: API key missing', { cause: 'api_key_missing' });
}
if (settings.source === 'ollama' && !textgenerationwebui_settings.server_urls[textgen_types.OLLAMA] ||
settings.source === 'llamacpp' && !textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP]) {
throw new Error('Vectors: API URL missing', { cause: 'api_url_missing' });
}
if (settings.source === 'ollama' && !settings.ollama_model) {
throw new Error('Vectors: API model missing', { cause: 'api_model_missing' });
}
if (settings.source === 'extras' && !modules.includes('embeddings')) {
throw new Error('Vectors: Embeddings module missing', { cause: 'extras_module_missing' });
}
}
/**
* Deletes vector items from a collection
* @param {string} collectionId - The collection to delete from
@ -870,6 +921,8 @@ function toggleSettings() {
$('#together_vectorsModel').toggle(settings.source === 'togetherai');
$('#openai_vectorsModel').toggle(settings.source === 'openai');
$('#cohere_vectorsModel').toggle(settings.source === 'cohere');
$('#ollama_vectorsModel').toggle(settings.source === 'ollama');
$('#llamacpp_vectorsModel').toggle(settings.source === 'llamacpp');
$('#nomicai_apiKey').toggle(settings.source === 'nomicai');
}
@ -1154,6 +1207,17 @@ jQuery(async () => {
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_ollama_model').val(settings.ollama_model).on('input', () => {
$('#vectors_modelWarning').show();
settings.ollama_model = String($('#vectors_ollama_model').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_ollama_keep').prop('checked', settings.ollama_keep).on('input', () => {
settings.ollama_keep = $('#vectors_ollama_keep').prop('checked');
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_template').val(settings.template).on('input', () => {
settings.template = String($('#vectors_template').val());
Object.assign(extension_settings.vectors, settings);

View File

@ -12,14 +12,37 @@
<select id="vectors_source" class="text_pole">
<option value="cohere">Cohere</option>
<option value="extras">Extras</option>
<option value="palm">Google MakerSuite (PaLM)</option>
<option value="palm">Google MakerSuite</option>
<option value="llamacpp">llama.cpp</option>
<option value="transformers">Local (Transformers)</option>
<option value="ollama">Ollama</option>
<option value="openai">OpenAI</option>
<option value="mistral">MistralAI</option>
<option value="nomicai">NomicAI</option>
<option value="openai">OpenAI</option>
<option value="togetherai">TogetherAI</option>
</select>
</div>
<div class="flex-container flexFlowColumn" id="ollama_vectorsModel">
<label for="vectors_ollama_model">
Vectorization Model
</label>
<input id="vectors_ollama_model" class="text_pole" type="text" placeholder="Model tag, e.g. llama3" />
<label for="vectors_ollama_keep" class="checkbox_label" title="When checked, the model will not be unloaded after use.">
<input id="vectors_ollama_keep" type="checkbox" />
<span>Keep model in memory</span>
</label>
<i>
Hint: Download models and set the URL in the API connection settings.
</i>
</div>
<div class="flex-container flexFlowColumn" id="llamacpp_vectorsModel">
<span>
The server MUST be started with the <code>--embedding</code> flag to use this feature!
</span>
<i>
Hint: Set the URL in the API connection settings.
</i>
</div>
<div class="flex-container flexFlowColumn" id="openai_vectorsModel">
<label for="vectors_openai_model">
Vectorization Model

View File

@ -164,6 +164,17 @@ function getOverrideHeaders(urlHost) {
* @param {string|null} server API server for new request
*/
function setAdditionalHeaders(request, args, server) {
setAdditionalHeadersByType(args.headers, request.body.api_type, server, request.user.directories);
}
/**
*
* @param {object} requestHeaders Request headers
* @param {string} type API type
* @param {string|null} server API server for new request
* @param {import('./users').UserDirectoryList} directories User directories
*/
function setAdditionalHeadersByType(requestHeaders, type, server, directories) {
const headerGetters = {
[TEXTGEN_TYPES.MANCER]: getMancerHeaders,
[TEXTGEN_TYPES.VLLM]: getVllmHeaders,
@ -178,13 +189,13 @@ function setAdditionalHeaders(request, args, server) {
[TEXTGEN_TYPES.LLAMACPP]: getLlamaCppHeaders,
};
const getHeaders = headerGetters[request.body.api_type];
const headers = getHeaders ? getHeaders(request.user.directories) : {};
const getHeaders = headerGetters[type];
const headers = getHeaders ? getHeaders(directories) : {};
if (typeof server === 'string' && server.length > 0) {
try {
const url = new URL(server);
const overrideHeaders = getOverrideHeaders(url.host);
const overrideHeaders = getOverrideHeaders(url.host);
if (overrideHeaders && Object.keys(overrideHeaders).length > 0) {
Object.assign(headers, overrideHeaders);
@ -194,10 +205,11 @@ function setAdditionalHeaders(request, args, server) {
}
}
Object.assign(args.headers, headers);
Object.assign(requestHeaders, headers);
}
module.exports = {
getOverrideHeaders,
setAdditionalHeaders,
setAdditionalHeadersByType,
};

View File

@ -5,7 +5,18 @@ const sanitize = require('sanitize-filename');
const { jsonParser } = require('../express-common');
// Don't forget to add new sources to the SOURCES array
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togetherai', 'nomicai', 'cohere'];
const SOURCES = [
'transformers',
'mistral',
'openai',
'extras',
'palm',
'togetherai',
'nomicai',
'cohere',
'ollama',
'llamacpp',
];
/**
* Gets the vector for the given text from the given source.
@ -32,6 +43,10 @@ async function getVector(source, sourceSettings, text, isQuery, directories) {
return require('../vectors/makersuite-vectors').getMakerSuiteVector(text, directories);
case 'cohere':
return require('../vectors/cohere-vectors').getCohereVector(text, isQuery, directories, sourceSettings.model);
case 'llamacpp':
return require('../vectors/llamacpp-vectors').getLlamaCppVector(text, sourceSettings.apiUrl, directories);
case 'ollama':
return require('../vectors/ollama-vectors').getOllamaVector(text, sourceSettings.apiUrl, sourceSettings.model, sourceSettings.keep, directories);
}
throw new Error(`Unknown vector source ${source}`);
@ -73,6 +88,12 @@ async function getBatchVector(source, sourceSettings, texts, isQuery, directorie
case 'cohere':
results.push(...await require('../vectors/cohere-vectors').getCohereBatchVector(batch, isQuery, directories, sourceSettings.model));
break;
case 'llamacpp':
results.push(...await require('../vectors/llamacpp-vectors').getLlamaCppBatchVector(batch, sourceSettings.apiUrl, directories));
break;
case 'ollama':
results.push(...await require('../vectors/ollama-vectors').getOllamaBatchVector(batch, sourceSettings.apiUrl, sourceSettings.model, sourceSettings.keep, directories));
break;
default:
throw new Error(`Unknown vector source ${source}`);
}
@ -251,7 +272,23 @@ function getSourceSettings(source, request) {
return {
model: model,
};
}else {
} else if (source === 'llamacpp') {
const apiUrl = String(request.headers['x-llamacpp-url']);
return {
apiUrl: apiUrl,
};
} else if (source === 'ollama') {
const apiUrl = String(request.headers['x-ollama-url']);
const model = String(request.headers['x-ollama-model']);
const keep = Boolean(request.headers['x-ollama-keep']);
return {
apiUrl: apiUrl,
model: model,
keep: keep,
};
} else {
// Extras API settings to connect to the Extras embeddings provider
let extrasUrl = '';
let extrasKey = '';

View File

@ -0,0 +1,61 @@
const fetch = require('node-fetch').default;
const { setAdditionalHeadersByType } = require('../additional-headers');
const { TEXTGEN_TYPES } = require('../constants');
/**
* Gets the vector for the given text from LlamaCpp
* @param {string[]} texts - The array of texts to get the vectors for
* @param {string} apiUrl - The API URL
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
* @returns {Promise<number[][]>} - The array of vectors for the texts
*/
async function getLlamaCppBatchVector(texts, apiUrl, directories) {
const url = new URL(apiUrl);
url.pathname = '/v1/embeddings';
const headers = {};
setAdditionalHeadersByType(headers, TEXTGEN_TYPES.LLAMACPP, apiUrl, directories);
const response = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...headers,
},
body: JSON.stringify({ input: texts }),
});
if (!response.ok) {
const responseText = await response.text();
throw new Error(`LlamaCpp: Failed to get vector for text: ${response.statusText} ${responseText}`);
}
const data = await response.json();
if (!Array.isArray(data?.data)) {
throw new Error('API response was not an array');
}
// Sort data by x.index to ensure the order is correct
data.data.sort((a, b) => a.index - b.index);
const vectors = data.data.map(x => x.embedding);
return vectors;
}
/**
* Gets the vector for the given text from LlamaCpp
* @param {string} text - The text to get the vector for
* @param {string} apiUrl - The API URL
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
* @returns {Promise<number[]>} - The vector for the text
*/
async function getLlamaCppVector(text, apiUrl, directories) {
const vectors = await getLlamaCppBatchVector([text], apiUrl, directories);
return vectors[0];
}
module.exports = {
getLlamaCppBatchVector,
getLlamaCppVector,
};

View File

@ -0,0 +1,69 @@
const fetch = require('node-fetch').default;
const { setAdditionalHeadersByType } = require('../additional-headers');
const { TEXTGEN_TYPES } = require('../constants');
/**
* Gets the vector for the given text from Ollama
* @param {string[]} texts - The array of texts to get the vectors for
* @param {string} apiUrl - The API URL
* @param {string} model - The model to use
* @param {boolean} keep - Keep the model loaded in memory
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
* @returns {Promise<number[][]>} - The array of vectors for the texts
*/
async function getOllamaBatchVector(texts, apiUrl, model, keep, directories) {
const result = [];
for (const text of texts) {
const vector = await getOllamaVector(text, apiUrl, model, keep, directories);
result.push(vector);
}
return result;
}
/**
* Gets the vector for the given text from Ollama
* @param {string} text - The text to get the vector for
* @param {string} apiUrl - The API URL
* @param {string} model - The model to use
* @param {boolean} keep - Keep the model loaded in memory
* @param {import('../users').UserDirectoryList} directories - The directories object for the user
* @returns {Promise<number[]>} - The vector for the text
*/
async function getOllamaVector(text, apiUrl, model, keep, directories) {
const url = new URL(apiUrl);
url.pathname = '/api/embeddings';
const headers = {};
setAdditionalHeadersByType(headers, TEXTGEN_TYPES.OLLAMA, apiUrl, directories);
const response = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...headers,
},
body: JSON.stringify({
prompt: text,
model: model,
keep_alive: keep ? -1 : undefined,
}),
});
if (!response.ok) {
const responseText = await response.text();
throw new Error(`Ollama: Failed to get vector for text: ${response.statusText} ${responseText}`);
}
const data = await response.json();
if (!Array.isArray(data?.embedding)) {
throw new Error('API response was not an array');
}
return data.embedding;
}
module.exports = {
getOllamaBatchVector,
getOllamaVector,
};