diff --git a/public/scripts/extensions/caption/index.js b/public/scripts/extensions/caption/index.js index 7b7421c8b..121b4d792 100644 --- a/public/scripts/extensions/caption/index.js +++ b/public/scripts/extensions/caption/index.js @@ -398,23 +398,62 @@ jQuery(async function () { $('#caption_wand_container').append(sendButton); $(sendButton).on('click', () => { - const hasCaptionModule = - (modules.includes('caption') && extension_settings.caption.source === 'extras') || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'openai' && (secret_state[SECRET_KEYS.OPENAI] || extension_settings.caption.allow_reverse_proxy)) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'openrouter' && secret_state[SECRET_KEYS.OPENROUTER]) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'zerooneai' && secret_state[SECRET_KEYS.ZEROONEAI]) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'groq' && secret_state[SECRET_KEYS.GROQ]) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'mistral' && (secret_state[SECRET_KEYS.MISTRALAI] || extension_settings.caption.allow_reverse_proxy)) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'google' && (secret_state[SECRET_KEYS.MAKERSUITE] || extension_settings.caption.allow_reverse_proxy)) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'anthropic' && (secret_state[SECRET_KEYS.CLAUDE] || extension_settings.caption.allow_reverse_proxy)) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'ollama' && textgenerationwebui_settings.server_urls[textgen_types.OLLAMA]) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'llamacpp' && textgenerationwebui_settings.server_urls[textgen_types.LLAMACPP]) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'ooba' && textgenerationwebui_settings.server_urls[textgen_types.OOBA]) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'koboldcpp' && textgenerationwebui_settings.server_urls[textgen_types.KOBOLDCPP]) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'vllm' && textgenerationwebui_settings.server_urls[textgen_types.VLLM]) || - (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'custom') || - extension_settings.caption.source === 'local' || - extension_settings.caption.source === 'horde'; + const hasCaptionModule = (() => { + const settings = extension_settings.caption; + + // Handle non-multimodal sources + if (settings.source === 'extras' && modules.includes('caption')) return true; + if (settings.source === 'local' || settings.source === 'horde') return true; + + // Handle multimodal sources + if (settings.source === 'multimodal') { + const api = settings.multimodal_api; + + // APIs that support reverse proxy + const reverseProxyApis = { + 'openai': SECRET_KEYS.OPENAI, + 'mistral': SECRET_KEYS.MISTRALAI, + 'google': SECRET_KEYS.MAKERSUITE, + 'anthropic': SECRET_KEYS.CLAUDE, + }; + + if (reverseProxyApis[api]) { + if (secret_state[reverseProxyApis[api]] || settings.allow_reverse_proxy) { + return true; + } + } + + const chatCompletionApis = { + 'openrouter': SECRET_KEYS.OPENROUTER, + 'zerooneai': SECRET_KEYS.ZEROONEAI, + 'groq': SECRET_KEYS.GROQ, + 'cohere': SECRET_KEYS.COHERE, + }; + + if (chatCompletionApis[api] && secret_state[chatCompletionApis[api]]) { + return true; + } + + const textCompletionApis = { + 'ollama': textgen_types.OLLAMA, + 'llamacpp': textgen_types.LLAMACPP, + 'ooba': textgen_types.OOBA, + 'koboldcpp': textgen_types.KOBOLDCPP, + 'vllm': textgen_types.VLLM, + }; + + if (textCompletionApis[api] && textgenerationwebui_settings.server_urls[textCompletionApis[api]]) { + return true; + } + + // Custom API doesn't need additional checks + if (api === 'custom') { + return true; + } + } + + return false; + })(); if (!hasCaptionModule) { toastr.error('Choose other captioning source in the extension settings.', 'Captioning is not available'); diff --git a/public/scripts/extensions/caption/settings.html b/public/scripts/extensions/caption/settings.html index 75ba9dea2..dfeb6b5b6 100644 --- a/public/scripts/extensions/caption/settings.html +++ b/public/scripts/extensions/caption/settings.html @@ -19,6 +19,7 @@ + + diff --git a/public/scripts/extensions/shared.js b/public/scripts/extensions/shared.js index a2c15ba41..aa274ec04 100644 --- a/public/scripts/extensions/shared.js +++ b/public/scripts/extensions/shared.js @@ -144,10 +144,14 @@ function throwIfInvalidModel(useReverseProxy) { throw new Error('Google AI Studio API key is not set.'); } - if (extension_settings.caption.multi_modal_api === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI] && !useReverseProxy) { + if (extension_settings.caption.multimodal_api === 'mistral' && !secret_state[SECRET_KEYS.MISTRALAI] && !useReverseProxy) { throw new Error('Mistral AI API key is not set.'); } + if (extension_settings.caption.multimodal_api === 'cohere' && !secret_state[SECRET_KEYS.COHERE]) { + throw new Error('Cohere API key is not set.'); + } + if (extension_settings.caption.multimodal_api === 'ollama' && !textgenerationwebui_settings.server_urls[textgen_types.OLLAMA]) { throw new Error('Ollama server URL is not set.'); } diff --git a/src/endpoints/openai.js b/src/endpoints/openai.js index 243b758ed..ef7c21ec8 100644 --- a/src/endpoints/openai.js +++ b/src/endpoints/openai.js @@ -62,6 +62,10 @@ router.post('/caption-image', jsonParser, async (request, response) => { key = readSecret(request.user.directories, SECRET_KEYS.GROQ); } + if (request.body.api === 'cohere') { + key = readSecret(request.user.directories, SECRET_KEYS.COHERE); + } + if (!key && !request.body.reverse_proxy && ['custom', 'ooba', 'koboldcpp', 'vllm'].includes(request.body.api) === false) { console.warn('No key found for API', request.body.api); return response.sendStatus(400); @@ -126,6 +130,10 @@ router.post('/caption-image', jsonParser, async (request, response) => { apiUrl = 'https://api.mistral.ai/v1/chat/completions'; } + if (request.body.api === 'cohere') { + apiUrl = 'https://api.cohere.ai/v2/chat'; + } + if (request.body.api === 'ooba') { apiUrl = `${trimV1(request.body.server_url)}/v1/chat/completions`; const imgMessage = body.messages.pop(); @@ -165,7 +173,7 @@ router.post('/caption-image', jsonParser, async (request, response) => { /** @type {any} */ const data = await result.json(); console.info('Multimodal captioning response', data); - const caption = data?.choices[0]?.message?.content; + const caption = data?.choices?.[0]?.message?.content ?? data?.message?.content?.[0]?.text; if (!caption) { return response.status(500).send('No caption found');