diff --git a/public/script.js b/public/script.js
index 8252c0b79..2b775650f 100644
--- a/public/script.js
+++ b/public/script.js
@@ -4836,7 +4836,7 @@ function extractMessageFromData(data) {
case 'novel':
return data.output;
case 'openai':
- return data?.choices?.[0]?.message?.content ?? data?.choices?.[0]?.text ?? '';
+ return data?.choices?.[0]?.message?.content ?? data?.choices?.[0]?.text ?? data?.text ?? '';
default:
return '';
}
@@ -8187,6 +8187,11 @@ const CONNECT_API_MAP = {
button: '#api_button_openai',
source: chat_completion_sources.CUSTOM,
},
+ 'cohere': {
+ selected: 'cohere',
+ button: '#api_button_openai',
+ source: chat_completion_sources.COHERE,
+ },
'infermaticai': {
selected: 'textgenerationwebui',
button: '#api_button_textgenerationwebui',
diff --git a/public/scripts/RossAscends-mods.js b/public/scripts/RossAscends-mods.js
index 054641268..87cbbff2c 100644
--- a/public/scripts/RossAscends-mods.js
+++ b/public/scripts/RossAscends-mods.js
@@ -350,6 +350,7 @@ function RA_autoconnect(PrevApi) {
|| (secret_state[SECRET_KEYS.AI21] && oai_settings.chat_completion_source == chat_completion_sources.AI21)
|| (secret_state[SECRET_KEYS.MAKERSUITE] && oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE)
|| (secret_state[SECRET_KEYS.MISTRALAI] && oai_settings.chat_completion_source == chat_completion_sources.MISTRALAI)
+ || (secret_state[SECRET_KEYS.COHERE] && oai_settings.chat_completion_source == chat_completion_sources.COHERE)
|| (isValidUrl(oai_settings.custom_url) && oai_settings.chat_completion_source == chat_completion_sources.CUSTOM)
) {
$('#api_button_openai').trigger('click');
diff --git a/public/scripts/openai.js b/public/scripts/openai.js
index e7b83c285..58060ccb9 100644
--- a/public/scripts/openai.js
+++ b/public/scripts/openai.js
@@ -171,6 +171,7 @@ export const chat_completion_sources = {
MAKERSUITE: 'makersuite',
MISTRALAI: 'mistralai',
CUSTOM: 'custom',
+ COHERE: 'cohere',
};
const character_names_behavior = {
@@ -230,6 +231,7 @@ const default_settings = {
google_model: 'gemini-pro',
ai21_model: 'j2-ultra',
mistralai_model: 'mistral-medium-latest',
+ cohere_model: 'command-r',
custom_model: '',
custom_url: '',
custom_include_body: '',
@@ -298,6 +300,7 @@ const oai_settings = {
google_model: 'gemini-pro',
ai21_model: 'j2-ultra',
mistralai_model: 'mistral-medium-latest',
+ cohere_model: 'command-r',
custom_model: '',
custom_url: '',
custom_include_body: '',
@@ -1384,6 +1387,8 @@ function getChatCompletionModel() {
return oai_settings.mistralai_model;
case chat_completion_sources.CUSTOM:
return oai_settings.custom_model;
+ case chat_completion_sources.COHERE:
+ return oai_settings.cohere_model;
default:
throw new Error(`Unknown chat completion source: ${oai_settings.chat_completion_source}`);
}
@@ -1603,6 +1608,7 @@ async function sendOpenAIRequest(type, messages, signal) {
const isOAI = oai_settings.chat_completion_source == chat_completion_sources.OPENAI;
const isMistral = oai_settings.chat_completion_source == chat_completion_sources.MISTRALAI;
const isCustom = oai_settings.chat_completion_source == chat_completion_sources.CUSTOM;
+ const isCohere = oai_settings.chat_completion_source == chat_completion_sources.COHERE;
const isTextCompletion = (isOAI && textCompletionModels.includes(oai_settings.openai_model)) || (isOpenRouter && oai_settings.openrouter_force_instruct && power_user.instruct.enabled);
const isQuiet = type === 'quiet';
const isImpersonate = type === 'impersonate';
@@ -1737,7 +1743,17 @@ async function sendOpenAIRequest(type, messages, signal) {
generate_data['custom_include_headers'] = oai_settings.custom_include_headers;
}
- if ((isOAI || isOpenRouter || isMistral || isCustom) && oai_settings.seed >= 0) {
+ if (isCohere) {
+ // Clamp to 0.01 -> 0.99
+ generate_data['top_p'] = Math.min(Math.max(Number(oai_settings.top_p_openai), 0.01), 0.99);
+ generate_data['top_k'] = Number(oai_settings.top_k_openai);
+ // Clamp to 0 -> 1
+ generate_data['frequency_penalty'] = Math.min(Math.max(Number(oai_settings.freq_pen_openai), 0), 1);
+ generate_data['presence_penalty'] = Math.min(Math.max(Number(oai_settings.pres_pen_openai), 0), 1);
+ generate_data['stop'] = getCustomStoppingStrings(5);
+ }
+
+ if ((isOAI || isOpenRouter || isMistral || isCustom || isCohere) && oai_settings.seed >= 0) {
generate_data['seed'] = oai_settings.seed;
}
@@ -2597,6 +2613,7 @@ function loadOpenAISettings(data, settings) {
oai_settings.openrouter_force_instruct = settings.openrouter_force_instruct ?? default_settings.openrouter_force_instruct;
oai_settings.ai21_model = settings.ai21_model ?? default_settings.ai21_model;
oai_settings.mistralai_model = settings.mistralai_model ?? default_settings.mistralai_model;
+ oai_settings.cohere_model = settings.cohere_model ?? default_settings.cohere_model;
oai_settings.custom_model = settings.custom_model ?? default_settings.custom_model;
oai_settings.custom_url = settings.custom_url ?? default_settings.custom_url;
oai_settings.custom_include_body = settings.custom_include_body ?? default_settings.custom_include_body;
@@ -2657,6 +2674,8 @@ function loadOpenAISettings(data, settings) {
$(`#model_ai21_select option[value="${oai_settings.ai21_model}"`).attr('selected', true);
$('#model_mistralai_select').val(oai_settings.mistralai_model);
$(`#model_mistralai_select option[value="${oai_settings.mistralai_model}"`).attr('selected', true);
+ $('#model_cohere_select').val(oai_settings.cohere_model);
+        $(`#model_cohere_select option[value="${oai_settings.cohere_model}"]`).attr('selected', true);
$('#custom_model_id').val(oai_settings.custom_model);
$('#custom_api_url_text').val(oai_settings.custom_url);
$('#openai_max_context').val(oai_settings.openai_max_context);
@@ -2893,6 +2912,7 @@ async function saveOpenAIPreset(name, settings, triggerUi = true) {
openrouter_sort_models: settings.openrouter_sort_models,
ai21_model: settings.ai21_model,
mistralai_model: settings.mistralai_model,
+ cohere_model: settings.cohere_model,
custom_model: settings.custom_model,
custom_url: settings.custom_url,
custom_include_body: settings.custom_include_body,
@@ -3281,6 +3301,7 @@ function onSettingsPresetChange() {
openrouter_sort_models: ['#openrouter_sort_models', 'openrouter_sort_models', false],
ai21_model: ['#model_ai21_select', 'ai21_model', false],
mistralai_model: ['#model_mistralai_select', 'mistralai_model', false],
+ cohere_model: ['#model_cohere_select', 'cohere_model', false],
custom_model: ['#custom_model_id', 'custom_model', false],
custom_url: ['#custom_api_url_text', 'custom_url', false],
custom_include_body: ['#custom_include_body', 'custom_include_body', false],
@@ -3496,6 +3517,11 @@ async function onModelChange() {
$('#model_mistralai_select').val(oai_settings.mistralai_model);
}
+ if ($(this).is('#model_cohere_select')) {
+ console.log('Cohere model changed to', value);
+ oai_settings.cohere_model = value;
+ }
+
if (value && $(this).is('#model_custom_select')) {
console.log('Custom model changed to', value);
oai_settings.custom_model = value;
@@ -3619,6 +3645,26 @@ async function onModelChange() {
$('#temp_openai').attr('max', claude_max_temp).val(oai_settings.temp_openai).trigger('input');
}
+ if (oai_settings.chat_completion_source === chat_completion_sources.COHERE) {
+ if (oai_settings.max_context_unlocked) {
+ $('#openai_max_context').attr('max', unlocked_max);
+ }
+ else if (['command-light', 'command'].includes(oai_settings.cohere_model)) {
+ $('#openai_max_context').attr('max', max_4k);
+ }
+ else if (['command-light-nightly', 'command-nightly'].includes(oai_settings.cohere_model)) {
+ $('#openai_max_context').attr('max', max_8k);
+ }
+ else if (['command-r'].includes(oai_settings.cohere_model)) {
+ $('#openai_max_context').attr('max', max_128k);
+ }
+ else {
+ $('#openai_max_context').attr('max', max_4k);
+ }
+ oai_settings.openai_max_context = Math.min(Number($('#openai_max_context').attr('max')), oai_settings.openai_max_context);
+ $('#openai_max_context').val(oai_settings.openai_max_context).trigger('input');
+ }
+
if (oai_settings.chat_completion_source == chat_completion_sources.AI21) {
if (oai_settings.max_context_unlocked) {
$('#openai_max_context').attr('max', unlocked_max);
@@ -3812,6 +3858,19 @@ async function onConnectButtonClick(e) {
}
}
+ if (oai_settings.chat_completion_source == chat_completion_sources.COHERE) {
+ const api_key_cohere = String($('#api_key_cohere').val()).trim();
+
+ if (api_key_cohere.length) {
+ await writeSecret(SECRET_KEYS.COHERE, api_key_cohere);
+ }
+
+ if (!secret_state[SECRET_KEYS.COHERE]) {
+ console.log('No secret key saved for Cohere');
+ return;
+ }
+ }
+
startStatusLoading();
saveSettingsDebounced();
await getStatusOpen();
@@ -3847,6 +3906,9 @@ function toggleChatCompletionForms() {
else if (oai_settings.chat_completion_source == chat_completion_sources.MISTRALAI) {
$('#model_mistralai_select').trigger('change');
}
+ else if (oai_settings.chat_completion_source == chat_completion_sources.COHERE) {
+ $('#model_cohere_select').trigger('change');
+ }
else if (oai_settings.chat_completion_source == chat_completion_sources.CUSTOM) {
$('#model_custom_select').trigger('change');
}
@@ -4499,6 +4561,7 @@ $(document).ready(async function () {
$('#openrouter_sort_models').on('change', onOpenrouterModelSortChange);
$('#model_ai21_select').on('change', onModelChange);
$('#model_mistralai_select').on('change', onModelChange);
+ $('#model_cohere_select').on('change', onModelChange);
$('#model_custom_select').on('change', onModelChange);
$('#settings_preset_openai').on('change', onSettingsPresetChange);
$('#new_oai_preset').on('click', onNewPresetClick);
diff --git a/public/scripts/secrets.js b/public/scripts/secrets.js
index a6d82e5e7..a6bed1057 100644
--- a/public/scripts/secrets.js
+++ b/public/scripts/secrets.js
@@ -23,6 +23,7 @@ export const SECRET_KEYS = {
NOMICAI: 'api_key_nomicai',
KOBOLDCPP: 'api_key_koboldcpp',
LLAMACPP: 'api_key_llamacpp',
+ COHERE: 'api_key_cohere',
};
const INPUT_MAP = {
@@ -47,6 +48,7 @@ const INPUT_MAP = {
[SECRET_KEYS.NOMICAI]: '#api_key_nomicai',
[SECRET_KEYS.KOBOLDCPP]: '#api_key_koboldcpp',
[SECRET_KEYS.LLAMACPP]: '#api_key_llamacpp',
+ [SECRET_KEYS.COHERE]: '#api_key_cohere',
};
async function clearSecret() {
diff --git a/public/scripts/slash-commands.js b/public/scripts/slash-commands.js
index aef1de058..70042dd3c 100644
--- a/public/scripts/slash-commands.js
+++ b/public/scripts/slash-commands.js
@@ -1660,6 +1660,7 @@ function modelCallback(_, model) {
{ id: 'model_google_select', api: 'openai', type: chat_completion_sources.MAKERSUITE },
{ id: 'model_mistralai_select', api: 'openai', type: chat_completion_sources.MISTRALAI },
{ id: 'model_custom_select', api: 'openai', type: chat_completion_sources.CUSTOM },
+ { id: 'model_cohere_select', api: 'openai', type: chat_completion_sources.COHERE },
{ id: 'model_novel_select', api: 'novel', type: null },
{ id: 'horde_model', api: 'koboldhorde', type: null },
];
diff --git a/src/constants.js b/src/constants.js
index db113a92c..918374eab 100644
--- a/src/constants.js
+++ b/src/constants.js
@@ -162,6 +162,7 @@ const CHAT_COMPLETION_SOURCES = {
MAKERSUITE: 'makersuite',
MISTRALAI: 'mistralai',
CUSTOM: 'custom',
+ COHERE: 'cohere',
};
const UPLOADS_PATH = './uploads';
diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js
index c695e230a..8fe7cb6bf 100644
--- a/src/endpoints/backends/chat-completions.js
+++ b/src/endpoints/backends/chat-completions.js
@@ -1,10 +1,11 @@
const express = require('express');
const fetch = require('node-fetch').default;
+const Readable = require('stream').Readable;
const { jsonParser } = require('../../express-common');
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
-const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt } = require('../../prompt-converters');
+const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages } = require('../../prompt-converters');
const { readSecret, SECRET_KEYS } = require('../secrets');
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
@@ -12,6 +13,61 @@ const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sente
const API_OPENAI = 'https://api.openai.com/v1';
const API_CLAUDE = 'https://api.anthropic.com/v1';
const API_MISTRAL = 'https://api.mistral.ai/v1';
+const API_COHERE = 'https://api.cohere.ai/v1';
+
+/**
+ * Ollama strikes back. Special boy #2's streaming routine.
+ * Wrap this abomination into proper SSE stream, again.
+ * @param {import('node-fetch').Response} jsonStream JSON stream
+ * @param {import('express').Request} request Express request
+ * @param {import('express').Response} response Express response
+ * @returns {Promise<any>} Nothing valuable
+ */
+async function parseCohereStream(jsonStream, request, response) {
+ try {
+ let partialData = '';
+        jsonStream.body.on('data', (data) => {
+            partialData += data.toString();
+            // Cohere emits newline-delimited JSON objects; one chunk may carry
+            // several events, so split on newlines instead of parsing the whole buffer.
+            const events = partialData.split('\n');
+            partialData = events.pop() ?? '';
+            for (const event of events) {
+                if (!event.trim()) continue;
+                let json;
+                try {
+                    json = JSON.parse(event);
+                } catch (e) {
+                    continue;
+                }
+                if (json.event_type === 'text-generation') {
+                    const text = json.text || '';
+                    const chunk = { choices: [{ text }] };
+                    response.write(`data: ${JSON.stringify(chunk)}\n\n`);
+                }
+            }
+        });
+
+ request.socket.on('close', function () {
+ if (jsonStream.body instanceof Readable) jsonStream.body.destroy();
+ response.end();
+ });
+
+ jsonStream.body.on('end', () => {
+ console.log('Streaming request finished');
+ response.write('data: [DONE]\n\n');
+ response.end();
+ });
+ } catch (error) {
+ console.log('Error forwarding streaming response:', error);
+ if (!response.headersSent) {
+ return response.status(500).send({ error: true });
+ } else {
+ return response.end();
+ }
+ }
+}
+
/**
* Sends a request to Claude API.
* @param {express.Request} request Express request
@@ -460,6 +516,85 @@ async function sendMistralAIRequest(request, response) {
}
}
+async function sendCohereRequest(request, response) {
+ const apiKey = readSecret(SECRET_KEYS.COHERE);
+ const controller = new AbortController();
+ request.socket.removeAllListeners('close');
+ request.socket.on('close', function () {
+ controller.abort();
+ });
+
+ if (!apiKey) {
+ console.log('Cohere API key is missing.');
+ return response.status(400).send({ error: true });
+ }
+
+ try {
+ const convertedHistory = convertCohereMessages(request.body.messages);
+
+ // https://docs.cohere.com/reference/chat
+ const requestBody = {
+ stream: Boolean(request.body.stream),
+ model: request.body.model,
+ message: convertedHistory.userPrompt,
+ preamble: convertedHistory.systemPrompt,
+ chat_history: convertedHistory.chatHistory,
+ temperature: request.body.temperature,
+ max_tokens: request.body.max_tokens,
+ k: request.body.top_k,
+ p: request.body.top_p,
+ seed: request.body.seed,
+ stop_sequences: request.body.stop,
+ frequency_penalty: request.body.frequency_penalty,
+ presence_penalty: request.body.presence_penalty,
+ prompt_truncation: 'AUTO_PRESERVE_ORDER',
+ connectors: [], // TODO
+ documents: [],
+ tools: [],
+ tool_results: [],
+ search_queries_only: false,
+ };
+
+ console.log('Cohere request:', requestBody);
+
+ const config = {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ 'Authorization': 'Bearer ' + apiKey,
+ },
+ body: JSON.stringify(requestBody),
+ signal: controller.signal,
+ timeout: 0,
+ };
+
+ const apiUrl = API_COHERE + '/chat';
+
+ if (request.body.stream) {
+ const stream = await fetch(apiUrl, config);
+ parseCohereStream(stream, request, response);
+ } else {
+ const generateResponse = await fetch(apiUrl, config);
+ if (!generateResponse.ok) {
+ console.log(`Cohere API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
+ // a 401 unauthorized response breaks the frontend auth, so return a 500 instead. prob a better way of dealing with this.
+ // 401s are already handled by the streaming processor and dont pop up an error toast, that should probably be fixed too.
+ return response.status(generateResponse.status === 401 ? 500 : generateResponse.status).send({ error: true });
+ }
+ const generateResponseJson = await generateResponse.json();
+ console.log('Cohere response:', generateResponseJson);
+ return response.send(generateResponseJson);
+ }
+ } catch (error) {
+ console.log('Error communicating with Cohere API: ', error);
+ if (!response.headersSent) {
+ response.send({ error: true });
+ } else {
+ response.end();
+ }
+ }
+}
+
const router = express.Router();
router.post('/status', jsonParser, async function (request, response_getstatus_openai) {
@@ -487,6 +622,10 @@ router.post('/status', jsonParser, async function (request, response_getstatus_o
api_key_openai = readSecret(SECRET_KEYS.CUSTOM);
headers = {};
mergeObjectWithYaml(headers, request.body.custom_include_headers);
+ } else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.COHERE) {
+ api_url = API_COHERE;
+ api_key_openai = readSecret(SECRET_KEYS.COHERE);
+ headers = {};
} else {
console.log('This chat completion source is not supported yet.');
return response_getstatus_openai.status(400).send({ error: true });
@@ -510,6 +649,10 @@ router.post('/status', jsonParser, async function (request, response_getstatus_o
const data = await response.json();
+        if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.COHERE && Array.isArray(data?.models)) {
+            data.data = data.models.map(model => ({ id: model.name, ...model }));
+        }
+
         response_getstatus_openai.send(data);
if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENROUTER && Array.isArray(data?.data)) {
let models = [];
@@ -635,6 +778,7 @@ router.post('/generate', jsonParser, function (request, response) {
case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response);
case CHAT_COMPLETION_SOURCES.MAKERSUITE: return sendMakerSuiteRequest(request, response);
case CHAT_COMPLETION_SOURCES.MISTRALAI: return sendMistralAIRequest(request, response);
+ case CHAT_COMPLETION_SOURCES.COHERE: return sendCohereRequest(request, response);
}
let apiUrl;
diff --git a/src/endpoints/secrets.js b/src/endpoints/secrets.js
index 55c5df008..afd41a1f7 100644
--- a/src/endpoints/secrets.js
+++ b/src/endpoints/secrets.js
@@ -35,6 +35,7 @@ const SECRET_KEYS = {
NOMICAI: 'api_key_nomicai',
KOBOLDCPP: 'api_key_koboldcpp',
LLAMACPP: 'api_key_llamacpp',
+ COHERE: 'api_key_cohere',
};
// These are the keys that are safe to expose, even if allowKeysExposure is false
diff --git a/src/polyfill.js b/src/polyfill.js
new file mode 100644
index 000000000..7bed18a1f
--- /dev/null
+++ b/src/polyfill.js
@@ -0,0 +1,8 @@
+if (!Array.prototype.findLastIndex) {
+ Array.prototype.findLastIndex = function (callback, thisArg) {
+ for (let i = this.length - 1; i >= 0; i--) {
+ if (callback.call(thisArg, this[i], i, this)) return i;
+ }
+ return -1;
+ };
+}
diff --git a/src/prompt-converters.js b/src/prompt-converters.js
index 42f7abaf7..72b75e223 100644
--- a/src/prompt-converters.js
+++ b/src/prompt-converters.js
@@ -1,3 +1,5 @@
+require('./polyfill.js');
+
/**
* Convert a prompt from the ChatML objects to the format used by Claude.
* @param {object[]} messages Array of messages
@@ -188,6 +190,64 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
return { messages: mergedMessages, systemPrompt: systemPrompt.trim() };
}
+/**
+ * Convert a prompt from the ChatML objects to the format used by Cohere.
+ * @param {object[]} messages Array of messages
+ * @param {string} charName Character name
+ * @param {string} userName User name
+ * @returns {{systemPrompt: string, chatHistory: object[], userPrompt: string}} Prompt for Cohere
+ */
+function convertCohereMessages(messages, charName = '', userName = '') {
+ const roleMap = {
+ 'system': 'SYSTEM',
+ 'user': 'USER',
+ 'assistant': 'CHATBOT',
+ };
+ const placeholder = '[Start a new chat]';
+ let systemPrompt = '';
+
+ // Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
+ let i;
+ for (i = 0; i < messages.length; i++) {
+ if (messages[i].role !== 'system') {
+ break;
+ }
+ // Append example names if not already done by the frontend (e.g. for group chats).
+ if (userName && messages[i].name === 'example_user') {
+ if (!messages[i].content.startsWith(`${userName}: `)) {
+ messages[i].content = `${userName}: ${messages[i].content}`;
+ }
+ }
+ if (charName && messages[i].name === 'example_assistant') {
+ if (!messages[i].content.startsWith(`${charName}: `)) {
+ messages[i].content = `${charName}: ${messages[i].content}`;
+ }
+ }
+ systemPrompt += `${messages[i].content}\n\n`;
+ }
+
+ messages.splice(0, i);
+
+ if (messages.length === 0) {
+ messages.unshift({
+ role: 'user',
+ content: placeholder,
+ });
+ }
+
+ const lastNonSystemMessageIndex = messages.findLastIndex(msg => msg.role === 'user' || msg.role === 'assistant');
+ const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || placeholder;
+
+ const chatHistory = messages.slice(0, lastNonSystemMessageIndex).map(msg => {
+ return {
+ role: roleMap[msg.role] || 'USER',
+ message: msg.content,
+ };
+ });
+
+ return { systemPrompt: systemPrompt.trim(), chatHistory, userPrompt };
+}
+
/**
* Convert a prompt from the ChatML objects to the format used by Google MakerSuite models.
* @param {object[]} messages Array of messages
@@ -300,4 +360,5 @@ module.exports = {
convertClaudeMessages,
convertGooglePrompt,
convertTextCompletionPrompt,
+ convertCohereMessages,
};