diff --git a/post-install.js b/post-install.js index 645085a8a..f787eeba3 100644 --- a/post-install.js +++ b/post-install.js @@ -60,7 +60,8 @@ function convertConfig() { try { console.log(color.blue('Converting config.conf to config.yaml. Your old config.conf will be renamed to config.conf.bak')); const config = require(path.join(process.cwd(), './config.conf')); - fs.renameSync('./config.conf', './config.conf.bak'); + fs.copyFileSync('./config.conf', './config.conf.bak'); + fs.rmSync('./config.conf'); fs.writeFileSync('./config.yaml', yaml.stringify(config)); console.log(color.green('Conversion successful. Please check your config.yaml and fix it if necessary.')); } catch (error) { diff --git a/public/scripts/PromptManager.js b/public/scripts/PromptManager.js index 295c12123..4ab734156 100644 --- a/public/scripts/PromptManager.js +++ b/public/scripts/PromptManager.js @@ -841,7 +841,7 @@ class PromptManager { const promptReferences = this.getPromptOrderForCharacter(this.activeCharacter); for (let i = promptReferences.length - 1; i >= 0; i--) { const reference = promptReferences[i]; - if (-1 === this.serviceSettings.prompts.findIndex(prompt => prompt.identifier === reference.identifier)) { + if (reference && -1 === this.serviceSettings.prompts.findIndex(prompt => prompt.identifier === reference.identifier)) { promptReferences.splice(i, 1); this.log('Removed unused reference: ' + reference.identifier); } diff --git a/public/scripts/extensions/stable-diffusion/comfyWorkflowEditor.html b/public/scripts/extensions/stable-diffusion/comfyWorkflowEditor.html index 784419d39..b510f29d9 100644 --- a/public/scripts/extensions/stable-diffusion/comfyWorkflowEditor.html +++ b/public/scripts/extensions/stable-diffusion/comfyWorkflowEditor.html @@ -19,6 +19,8 @@
  • "%scale%"
  • "%width%"
  • "%height%"
  • +
  • "%user_avatar%"
  • +
  • "%char_avatar%"

  • "%seed%" diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index ee3e9a477..6618b66d4 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -2111,21 +2111,11 @@ async function generateMultimodalPrompt(generationType, quietPrompt) { let avatarUrl; if (generationType == generationMode.USER_MULTIMODAL) { - avatarUrl = getUserAvatar(user_avatar); + avatarUrl = getUserAvatarUrl(); } if (generationType == generationMode.CHARACTER_MULTIMODAL || generationType === generationMode.FACE_MULTIMODAL) { - const context = getContext(); - - if (context.groupId) { - const groupMembers = context.groups.find(x => x.id === context.groupId)?.members; - const lastMessageAvatar = context.chat?.filter(x => !x.is_system && !x.is_user)?.slice(-1)[0]?.original_avatar; - const randomMemberAvatar = Array.isArray(groupMembers) ? groupMembers[Math.floor(Math.random() * groupMembers.length)]?.avatar : null; - const avatarToUse = lastMessageAvatar || randomMemberAvatar; - avatarUrl = formatCharacterAvatar(avatarToUse); - } else { - avatarUrl = getCharacterAvatar(context.characterId); - } + avatarUrl = getCharacterAvatarUrl(); } try { @@ -2152,6 +2142,24 @@ async function generateMultimodalPrompt(generationType, quietPrompt) { } } +function getCharacterAvatarUrl() { + const context = getContext(); + + if (context.groupId) { + const groupMembers = context.groups.find(x => x.id === context.groupId)?.members; + const lastMessageAvatar = context.chat?.filter(x => !x.is_system && !x.is_user)?.slice(-1)[0]?.original_avatar; + const randomMemberAvatar = Array.isArray(groupMembers) ? groupMembers[Math.floor(Math.random() * groupMembers.length)]?.avatar : null; + const avatarToUse = lastMessageAvatar || randomMemberAvatar; + return formatCharacterAvatar(avatarToUse); + } else { + return getCharacterAvatar(context.characterId); + } +} + +function getUserAvatarUrl() { + return getUserAvatar(user_avatar); +} + /** * Generates a prompt using the main LLM API. * @param {string} quietPrompt - The prompt to use for the image generation. @@ -2636,6 +2644,22 @@ async function generateComfyImage(prompt, negativePrompt) { (extension_settings.sd.comfy_placeholders ?? 
[]).forEach(ph => { workflow = workflow.replace(`"%${ph.find}%"`, JSON.stringify(substituteParams(ph.replace))); }); + if (/%user_avatar%/gi.test(workflow)) { + const response = await fetch(getUserAvatarUrl()); + if (response.ok) { + const avatarBlob = await response.blob(); + const avatarBase64 = await getBase64Async(avatarBlob); + workflow = workflow.replace('"%user_avatar%"', JSON.stringify(avatarBase64)); + } + } + if (/%char_avatar%/gi.test(workflow)) { + const response = await fetch(getCharacterAvatarUrl()); + if (response.ok) { + const avatarBlob = await response.blob(); + const avatarBase64 = await getBase64Async(avatarBlob); + workflow = workflow.replace('"%char_avatar%"', JSON.stringify(avatarBase64)); + } + } console.log(`{ "prompt": ${workflow} }`); @@ -2649,6 +2673,10 @@ async function generateComfyImage(prompt, negativePrompt) { }`, }), }); + if (!promptResult.ok) { + const text = await promptResult.text(); + throw new Error(text); + } return { format: 'png', data: await promptResult.text() }; } diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js index 377878d8b..3d6238b11 100644 --- a/public/scripts/extensions/vectors/index.js +++ b/public/scripts/extensions/vectors/index.js @@ -35,6 +35,7 @@ const settings = { include_wi: false, togetherai_model: 'togethercomputer/m2-bert-80M-32k-retrieval', openai_model: 'text-embedding-ada-002', + cohere_model: 'embed-english-v3.0', summarize: false, summarize_sent: false, summary_source: 'main', @@ -68,6 +69,15 @@ const settings = { const moduleWorker = new ModuleWorkerWrapper(synchronizeChat); +/** + * Gets the Collection ID for a file embedded in the chat. + * @param {string} fileUrl URL of the file + * @returns {string} Collection ID + */ +function getFileCollectionId(fileUrl) { + return `file_${getStringHash(fileUrl)}`; +} + async function onVectorizeAllClick() { try { if (!settings.enabled_chats) { @@ -308,7 +318,7 @@ async function processFiles(chat) { const dataBankCollectionIds = []; for (const file of dataBank) { - const collectionId = `file_${getStringHash(file.url)}`; + const collectionId = getFileCollectionId(file.url); const hashesInCollection = await getSavedHashes(collectionId); dataBankCollectionIds.push(collectionId); @@ -354,7 +364,7 @@ async function processFiles(chat) { const fileName = message.extra.file.name; const fileUrl = message.extra.file.url; - const collectionId = `file_${getStringHash(fileUrl)}`; + const collectionId = getFileCollectionId(fileUrl); const hashesInCollection = await getSavedHashes(collectionId); // File is already in the collection @@ -598,6 +608,9 @@ function getVectorHeaders() { case 'openai': addOpenAiHeaders(headers); break; + case 'cohere': + addCohereHeaders(headers); + break; default: break; } @@ -636,6 +649,16 @@ function addOpenAiHeaders(headers) { }); } +/** + * Add headers for the Cohere API source. 
+ * @param {object} headers Header object + */ +function addCohereHeaders(headers) { + Object.assign(headers, { + 'X-Cohere-Model': extension_settings.vectors.cohere_model, + }); +} + /** * Inserts vector items into a collection * @param {string} collectionId - The collection to insert into @@ -647,7 +670,8 @@ async function insertVectorItems(collectionId, items) { settings.source === 'palm' && !secret_state[SECRET_KEYS.MAKERSUITE] || settings.source === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI] || settings.source === 'togetherai' && !secret_state[SECRET_KEYS.TOGETHERAI] || - settings.source === 'nomicai' && !secret_state[SECRET_KEYS.NOMICAI]) { + settings.source === 'nomicai' && !secret_state[SECRET_KEYS.NOMICAI] || + settings.source === 'cohere' && !secret_state[SECRET_KEYS.COHERE]) { throw new Error('Vectors: API key missing', { cause: 'api_key_missing' }); } @@ -760,7 +784,7 @@ async function purgeFileVectorIndex(fileUrl) { } console.log(`Vectors: Purging file vector index for ${fileUrl}`); - const collectionId = `file_${getStringHash(fileUrl)}`; + const collectionId = getFileCollectionId(fileUrl); const response = await fetch('/api/vector/purge', { method: 'POST', @@ -816,6 +840,7 @@ function toggleSettings() { $('#vectors_chats_settings').toggle(!!settings.enabled_chats); $('#together_vectorsModel').toggle(settings.source === 'togetherai'); $('#openai_vectorsModel').toggle(settings.source === 'openai'); + $('#cohere_vectorsModel').toggle(settings.source === 'cohere'); $('#nomicai_apiKey').toggle(settings.source === 'nomicai'); } @@ -859,6 +884,42 @@ async function onViewStatsClick() { } +async function onVectorizeAllFilesClick() { + try { + const dataBank = getDataBankAttachments(); + const chatAttachments = getContext().chat.filter(x => x.extra?.file).map(x => x.extra.file); + const allFiles = [...dataBank, ...chatAttachments]; + + for (const file of allFiles) { + const text = await getFileAttachment(file.url); + const collectionId = getFileCollectionId(file.url); + await vectorizeFile(text, file.name, collectionId, settings.chunk_size); + } + + toastr.success('All files vectorized', 'Vectorization successful'); + } catch (error) { + console.error('Vectors: Failed to vectorize all files', error); + toastr.error('Failed to vectorize all files', 'Vectorization failed'); + } +} + +async function onPurgeFilesClick() { + try { + const dataBank = getDataBankAttachments(); + const chatAttachments = getContext().chat.filter(x => x.extra?.file).map(x => x.extra.file); + const allFiles = [...dataBank, ...chatAttachments]; + + for (const file of allFiles) { + await purgeFileVectorIndex(file.url); + } + + toastr.success('All files purged', 'Purge successful'); + } catch (error) { + console.error('Vectors: Failed to purge all files', error); + toastr.error('Failed to purge all files', 'Purge failed'); + } +} + jQuery(async () => { if (!extension_settings.vectors) { extension_settings.vectors = settings; @@ -913,6 +974,12 @@ jQuery(async () => { Object.assign(extension_settings.vectors, settings); saveSettingsDebounced(); }); + $('#vectors_cohere_model').val(settings.cohere_model).on('change', () => { + $('#vectors_modelWarning').show(); + settings.cohere_model = String($('#vectors_cohere_model').val()); + Object.assign(extension_settings.vectors, settings); + saveSettingsDebounced(); + }); $('#vectors_template').val(settings.template).on('input', () => { settings.template = String($('#vectors_template').val()); Object.assign(extension_settings.vectors, settings); @@ -947,6 +1014,8 @@ 
jQuery(async () => { $('#vectors_vectorize_all').on('click', onVectorizeAllClick); $('#vectors_purge').on('click', onPurgeClick); $('#vectors_view_stats').on('click', onViewStatsClick); + $('#vectors_files_vectorize_all').on('click', onVectorizeAllFilesClick); + $('#vectors_files_purge').on('click', onPurgeFilesClick); $('#vectors_size_threshold').val(settings.size_threshold).on('input', () => { settings.size_threshold = Number($('#vectors_size_threshold').val()); diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html index 98c807cd9..cdb91981b 100644 --- a/public/scripts/extensions/vectors/settings.html +++ b/public/scripts/extensions/vectors/settings.html @@ -10,13 +10,14 @@ Vectorization Source
    @@ -29,6 +30,20 @@

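Note: the HTML added by the two settings.html hunks above is not shown. A minimal sketch of what the additions likely look like, based only on identifiers visible in the accompanying vectors/index.js changes (the 'cohere' source option, #cohere_vectorsModel, #vectors_cohere_model, and the embed-english-v3.0 default); the surrounding structure, class names, label text, and option list are assumptions:

    <!-- assumed: new entry in the Vectorization Source select -->
    <option value="cohere">Cohere</option>

    <!-- assumed: model picker block toggled by index.js when source === 'cohere' -->
    <div id="cohere_vectorsModel">
        <label for="vectors_cohere_model">Vectorization Model</label>
        <select id="vectors_cohere_model" class="text_pole">
            <option value="embed-english-v3.0">embed-english-v3.0</option>
        </select>
    </div>
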
    diff --git a/public/scripts/instruct-mode.js b/public/scripts/instruct-mode.js index ba1b1506e..3837cd7eb 100644 --- a/public/scripts/instruct-mode.js +++ b/public/scripts/instruct-mode.js @@ -354,7 +354,9 @@ export function formatInstructModeSystemPrompt(systemPrompt) { const separator = power_user.instruct.wrap ? '\n' : ''; if (power_user.instruct.system_sequence_prefix) { - systemPrompt = power_user.instruct.system_sequence_prefix + separator + systemPrompt; + // TODO: Replace with a proper 'System' prompt entity name input + const prefix = power_user.instruct.system_sequence_prefix.replace(/{{name}}/gi, 'System'); + systemPrompt = prefix + separator + systemPrompt; } if (power_user.instruct.system_sequence_suffix) { diff --git a/public/scripts/popup.js b/public/scripts/popup.js index 4e75431f4..c69aa6e60 100644 --- a/public/scripts/popup.js +++ b/public/scripts/popup.js @@ -119,11 +119,15 @@ export class Popup { const keyListener = (evt) => { switch (evt.key) { case 'Escape': { - evt.preventDefault(); - evt.stopPropagation(); - this.completeCancelled(); - window.removeEventListener('keydown', keyListenerBound); - break; + // does it really matter where we check? + const topModal = document.elementFromPoint(window.innerWidth / 2, window.innerHeight / 2)?.closest('.shadow_popup'); + if (topModal == this.dom) { + evt.preventDefault(); + evt.stopPropagation(); + this.completeCancelled(); + window.removeEventListener('keydown', keyListenerBound); + break; + } } } }; diff --git a/src/endpoints/assets.js b/src/endpoints/assets.js index a9dc317d6..78f5270a7 100644 --- a/src/endpoints/assets.js +++ b/src/endpoints/assets.js @@ -227,7 +227,8 @@ router.post('/download', jsonParser, async (request, response) => { // Move into asset place console.debug('Download finished, moving file from', temp_path, 'to', file_path); - fs.renameSync(temp_path, file_path); + fs.copyFileSync(temp_path, file_path); + fs.rmSync(temp_path); response.sendStatus(200); } catch (error) { diff --git a/src/endpoints/backgrounds.js b/src/endpoints/backgrounds.js index 33419ef4f..b8965ab5f 100644 --- a/src/endpoints/backgrounds.js +++ b/src/endpoints/backgrounds.js @@ -51,7 +51,8 @@ router.post('/rename', jsonParser, function (request, response) { return response.sendStatus(400); } - fs.renameSync(oldFileName, newFileName); + fs.copyFileSync(oldFileName, newFileName); + fs.rmSync(oldFileName); invalidateThumbnail(request.user.directories, 'bg', request.body.old_bg); return response.send('ok'); }); @@ -63,7 +64,8 @@ router.post('/upload', urlencodedParser, function (request, response) { const filename = request.file.originalname; try { - fs.renameSync(img_path, path.join(request.user.directories.backgrounds, filename)); + fs.copyFileSync(img_path, path.join(request.user.directories.backgrounds, filename)); + fs.rmSync(img_path); invalidateThumbnail(request.user.directories, 'bg', filename); response.send(filename); } catch (err) { diff --git a/src/endpoints/characters.js b/src/endpoints/characters.js index f9ff18688..556a8fecc 100644 --- a/src/endpoints/characters.js +++ b/src/endpoints/characters.js @@ -680,7 +680,8 @@ router.post('/rename', jsonParser, async function (request, response) { // Rename chats folder if (fs.existsSync(oldChatsPath) && !fs.existsSync(newChatsPath)) { - fs.renameSync(oldChatsPath, newChatsPath); + fs.cpSync(oldChatsPath, newChatsPath, { recursive: true }); + fs.rmSync(oldChatsPath, { recursive: true, force: true }); } // Remove the old character file diff --git a/src/endpoints/chats.js 
b/src/endpoints/chats.js index 49cf98e01..ff55d3ff0 100644 --- a/src/endpoints/chats.js +++ b/src/endpoints/chats.js @@ -213,8 +213,9 @@ router.post('/rename', jsonParser, async function (request, response) { return response.status(400).send({ error: true }); } + fs.copyFileSync(pathToOriginalFile, pathToRenamedFile); + fs.rmSync(pathToOriginalFile); console.log('Successfully renamed.'); - fs.renameSync(pathToOriginalFile, pathToRenamedFile); return response.send({ ok: true }); }); diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js index b495de752..990796fb1 100644 --- a/src/endpoints/vectors.js +++ b/src/endpoints/vectors.js @@ -12,23 +12,26 @@ const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togethe * @param {string} source - The source of the vector * @param {Object} sourceSettings - Settings for the source, if it needs any * @param {string} text - The text to get the vector for + * @param {boolean} isQuery - If the text is a query for embedding search * @param {import('../users').UserDirectoryList} directories - The directories object for the user * @returns {Promise} - The vector for the text */ -async function getVector(source, sourceSettings, text, directories) { +async function getVector(source, sourceSettings, text, isQuery, directories) { switch (source) { case 'nomicai': - return require('../nomicai-vectors').getNomicAIVector(text, source, directories); + return require('../vectors/nomicai-vectors').getNomicAIVector(text, source, directories); case 'togetherai': case 'mistral': case 'openai': - return require('../openai-vectors').getOpenAIVector(text, source, directories, sourceSettings.model); + return require('../vectors/openai-vectors').getOpenAIVector(text, source, directories, sourceSettings.model); case 'transformers': - return require('../embedding').getTransformersVector(text); + return require('../vectors/embedding').getTransformersVector(text); case 'extras': - return require('../extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey); + return require('../vectors/extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey); case 'palm': - return require('../makersuite-vectors').getMakerSuiteVector(text, directories); + return require('../vectors/makersuite-vectors').getMakerSuiteVector(text, directories); + case 'cohere': + return require('../vectors/cohere-vectors').getCohereVector(text, isQuery, directories, sourceSettings.model); } throw new Error(`Unknown vector source ${source}`); @@ -39,10 +42,11 @@ async function getVector(source, sourceSettings, text, directories) { * @param {string} source - The source of the vector * @param {Object} sourceSettings - Settings for the source, if it needs any * @param {string[]} texts - The array of texts to get the vector for + * @param {boolean} isQuery - If the text is a query for embedding search * @param {import('../users').UserDirectoryList} directories - The directories object for the user * @returns {Promise} - The array of vectors for the texts */ -async function getBatchVector(source, sourceSettings, texts, directories) { +async function getBatchVector(source, sourceSettings, texts, isQuery, directories) { const batchSize = 10; const batches = Array(Math.ceil(texts.length / batchSize)).fill(undefined).map((_, i) => texts.slice(i * batchSize, i * batchSize + batchSize)); @@ -50,21 +54,24 @@ async function getBatchVector(source, sourceSettings, texts, directories) { for (let batch of batches) { switch (source) { case 
'nomicai': - results.push(...await require('../nomicai-vectors').getNomicAIBatchVector(batch, source, directories)); + results.push(...await require('../vectors/nomicai-vectors').getNomicAIBatchVector(batch, source, directories)); break; case 'togetherai': case 'mistral': case 'openai': - results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, directories, sourceSettings.model)); + results.push(...await require('../vectors/openai-vectors').getOpenAIBatchVector(batch, source, directories, sourceSettings.model)); break; case 'transformers': - results.push(...await require('../embedding').getTransformersBatchVector(batch)); + results.push(...await require('../vectors/embedding').getTransformersBatchVector(batch)); break; case 'extras': - results.push(...await require('../extras-vectors').getExtrasBatchVector(batch, sourceSettings.extrasUrl, sourceSettings.extrasKey)); + results.push(...await require('../vectors/extras-vectors').getExtrasBatchVector(batch, sourceSettings.extrasUrl, sourceSettings.extrasKey)); break; case 'palm': - results.push(...await require('../makersuite-vectors').getMakerSuiteBatchVector(batch, directories)); + results.push(...await require('../vectors/makersuite-vectors').getMakerSuiteBatchVector(batch, directories)); + break; + case 'cohere': + results.push(...await require('../vectors/cohere-vectors').getCohereBatchVector(batch, isQuery, directories, sourceSettings.model)); break; default: throw new Error(`Unknown vector source ${source}`); @@ -106,7 +113,7 @@ async function insertVectorItems(directories, collectionId, source, sourceSettin await store.beginUpdate(); - const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text), directories); + const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text), false, directories); for (let i = 0; i < items.length; i++) { const item = items[i]; @@ -165,7 +172,7 @@ async function deleteVectorItems(directories, collectionId, source, hashes) { */ async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK) { const store = await getIndex(directories, collectionId, source); - const vector = await getVector(source, sourceSettings, searchText, directories); + const vector = await getVector(source, sourceSettings, searchText, true, directories); const result = await store.queryItems(vector, topK); const metadata = result.map(x => x.item.metadata); @@ -184,7 +191,7 @@ async function queryCollection(directories, collectionId, source, sourceSettings * @returns {Promise>} - The top K results from each collection */ async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK) { - const vector = await getVector(source, sourceSettings, searchText, directories); + const vector = await getVector(source, sourceSettings, searchText, true, directories); const results = []; for (const collectionId of collectionIds) { @@ -223,18 +230,24 @@ async function multiQueryCollection(directories, collectionIds, source, sourceSe */ function getSourceSettings(source, request) { if (source === 'togetherai') { - let model = String(request.headers['x-togetherai-model']); + const model = String(request.headers['x-togetherai-model']); return { model: model, }; } else if (source === 'openai') { - let model = String(request.headers['x-openai-model']); + const model = String(request.headers['x-openai-model']); return { model: model, }; - } else { + } else if (source === 'cohere') { + const model = 
String(request.headers['x-cohere-model']); + + return { + model: model, + }; + }else { // Extras API settings to connect to the Extras embeddings provider let extrasUrl = ''; let extrasKey = ''; diff --git a/src/users.js b/src/users.js index 8781334e3..831ff09d5 100644 --- a/src/users.js +++ b/src/users.js @@ -286,12 +286,22 @@ async function migrateUserData() { // Copy the file to the new location fs.cpSync(migration.old, migration.new, { force: true }); // Move the file to the backup location - fs.renameSync(migration.old, path.join(backupDirectory, path.basename(migration.old))); + fs.cpSync( + migration.old, + path.join(backupDirectory, path.basename(migration.old)), + { recursive: true, force: true } + ); + fs.rmSync(migration.old, { recursive: true, force: true }); } else { // Copy the directory to the new location fs.cpSync(migration.old, migration.new, { recursive: true, force: true }); // Move the directory to the backup location - fs.renameSync(migration.old, path.join(backupDirectory, path.basename(migration.old))); + fs.cpSync( + migration.old, + path.join(backupDirectory, path.basename(migration.old)), + { recursive: true, force: true } + ); + fs.rmSync(migration.old, { recursive: true, force: true }); } } catch (error) { console.error(color.red(`Error migrating ${migration.old} to ${migration.new}:`), error.message); diff --git a/src/vectors/cohere-vectors.js b/src/vectors/cohere-vectors.js new file mode 100644 index 000000000..1ec01130c --- /dev/null +++ b/src/vectors/cohere-vectors.js @@ -0,0 +1,65 @@ +const fetch = require('node-fetch').default; +const { SECRET_KEYS, readSecret } = require('../endpoints/secrets'); + +/** + * Gets the vector for the given text batch from an OpenAI compatible endpoint. + * @param {string[]} texts - The array of texts to get the vector for + * @param {boolean} isQuery - If the text is a query for embedding search + * @param {import('../users').UserDirectoryList} directories - The directories object for the user + * @param {string} model - The model to use for the embedding + * @returns {Promise} - The array of vectors for the texts + */ +async function getCohereBatchVector(texts, isQuery, directories, model) { + const key = readSecret(directories, SECRET_KEYS.COHERE); + + if (!key) { + console.log('No API key found'); + throw new Error('No API key found'); + } + + const response = await fetch('https://api.cohere.ai/v1/embed', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${key}`, + }, + body: JSON.stringify({ + texts: texts, + model: model, + input_type: isQuery ? 'search_query' : 'search_document', + truncate: 'END', + }), + }); + + if (!response.ok) { + const text = await response.text(); + console.log('API request failed', response.statusText, text); + throw new Error('API request failed'); + } + + const data = await response.json(); + if (!Array.isArray(data?.embeddings)) { + console.log('API response was not an array'); + throw new Error('API response was not an array'); + } + + return data.embeddings; +} + +/** + * Gets the vector for the given text from an OpenAI compatible endpoint. 
+ * @param {string} text - The text to get the vector for + * @param {boolean} isQuery - If the text is a query for embedding search + * @param {import('../users').UserDirectoryList} directories - The directories object for the user + * @param {string} model - The model to use for the embedding + * @returns {Promise} - The vector for the text + */ +async function getCohereVector(text, isQuery, directories, model) { + const vectors = await getCohereBatchVector([text], isQuery, directories, model); + return vectors[0]; +} + +module.exports = { + getCohereBatchVector, + getCohereVector, +}; diff --git a/src/embedding.js b/src/vectors/embedding.js similarity index 94% rename from src/embedding.js rename to src/vectors/embedding.js index eabc0cc43..3f02e07db 100644 --- a/src/embedding.js +++ b/src/vectors/embedding.js @@ -6,7 +6,7 @@ const TASK = 'feature-extraction'; * @returns {Promise} - The vectorized text in form of an array of numbers */ async function getTransformersVector(text) { - const module = await import('./transformers.mjs'); + const module = await import('../transformers.mjs'); const pipe = await module.default.getPipeline(TASK); const result = await pipe(text, { pooling: 'mean', normalize: true }); const vector = Array.from(result.data); diff --git a/src/extras-vectors.js b/src/vectors/extras-vectors.js similarity index 100% rename from src/extras-vectors.js rename to src/vectors/extras-vectors.js diff --git a/src/makersuite-vectors.js b/src/vectors/makersuite-vectors.js similarity index 85% rename from src/makersuite-vectors.js rename to src/vectors/makersuite-vectors.js index 279e7c253..b0ea928ff 100644 --- a/src/makersuite-vectors.js +++ b/src/vectors/makersuite-vectors.js @@ -1,10 +1,10 @@ const fetch = require('node-fetch').default; -const { SECRET_KEYS, readSecret } = require('./endpoints/secrets'); +const { SECRET_KEYS, readSecret } = require('../endpoints/secrets'); /** * Gets the vector for the given text from gecko model * @param {string[]} texts - The array of texts to get the vector for - * @param {import('./users').UserDirectoryList} directories - The directories object for the user + * @param {import('../users').UserDirectoryList} directories - The directories object for the user * @returns {Promise} - The array of vectors for the texts */ async function getMakerSuiteBatchVector(texts, directories) { @@ -16,7 +16,7 @@ async function getMakerSuiteBatchVector(texts, directories) { /** * Gets the vector for the given text from PaLM gecko model * @param {string} text - The text to get the vector for - * @param {import('./users').UserDirectoryList} directories - The directories object for the user + * @param {import('../users').UserDirectoryList} directories - The directories object for the user * @returns {Promise} - The vector for the text */ async function getMakerSuiteVector(text, directories) { diff --git a/src/nomicai-vectors.js b/src/vectors/nomicai-vectors.js similarity index 88% rename from src/nomicai-vectors.js rename to src/vectors/nomicai-vectors.js index 2ac682b7d..29b322926 100644 --- a/src/nomicai-vectors.js +++ b/src/vectors/nomicai-vectors.js @@ -1,5 +1,5 @@ const fetch = require('node-fetch').default; -const { SECRET_KEYS, readSecret } = require('./endpoints/secrets'); +const { SECRET_KEYS, readSecret } = require('../endpoints/secrets'); const SOURCES = { 'nomicai': { @@ -13,7 +13,7 @@ const SOURCES = { * Gets the vector for the given text batch from an OpenAI compatible endpoint. 
* @param {string[]} texts - The array of texts to get the vector for * @param {string} source - The source of the vector - * @param {import('./users').UserDirectoryList} directories - The directories object for the user + * @param {import('../users').UserDirectoryList} directories - The directories object for the user * @returns {Promise} - The array of vectors for the texts */ async function getNomicAIBatchVector(texts, source, directories) { @@ -64,7 +64,7 @@ async function getNomicAIBatchVector(texts, source, directories) { * Gets the vector for the given text from an OpenAI compatible endpoint. * @param {string} text - The text to get the vector for * @param {string} source - The source of the vector - * @param {import('./users').UserDirectoryList} directories - The directories object for the user + * @param {import('../users').UserDirectoryList} directories - The directories object for the user * @returns {Promise} - The vector for the text */ async function getNomicAIVector(text, source, directories) { diff --git a/src/openai-vectors.js b/src/vectors/openai-vectors.js similarity index 91% rename from src/openai-vectors.js rename to src/vectors/openai-vectors.js index d748658bb..6a30d3f2b 100644 --- a/src/openai-vectors.js +++ b/src/vectors/openai-vectors.js @@ -1,5 +1,5 @@ const fetch = require('node-fetch').default; -const { SECRET_KEYS, readSecret } = require('./endpoints/secrets'); +const { SECRET_KEYS, readSecret } = require('../endpoints/secrets'); const SOURCES = { 'togetherai': { @@ -23,7 +23,7 @@ const SOURCES = { * Gets the vector for the given text batch from an OpenAI compatible endpoint. * @param {string[]} texts - The array of texts to get the vector for * @param {string} source - The source of the vector - * @param {import('./users').UserDirectoryList} directories - The directories object for the user + * @param {import('../users').UserDirectoryList} directories - The directories object for the user * @param {string} model - The model to use for the embedding * @returns {Promise} - The array of vectors for the texts */ @@ -79,7 +79,7 @@ async function getOpenAIBatchVector(texts, source, directories, model = '') { * Gets the vector for the given text from an OpenAI compatible endpoint. * @param {string} text - The text to get the vector for * @param {string} source - The source of the vector - * @param {import('./users').UserDirectoryList} directories - The directories object for the user + * @param {import('../users').UserDirectoryList} directories - The directories object for the user * @param {string} model - The model to use for the embedding * @returns {Promise} - The vector for the text */