#1994 Add Cohere as a Chat Completion source

This commit is contained in:
Cohee 2024-04-02 00:20:17 +03:00
parent 9c6d8e6895
commit 9838ba8044
12 changed files with 347 additions and 19 deletions

12
public/img/cohere.svg Normal file
View File

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!-- Created with Inkscape (http://www.inkscape.org/) -->
<svg width="47.403999mm" height="47.58918mm" viewBox="0 0 47.403999 47.58918" version="1.1" id="svg1" xml:space="preserve" inkscape:version="1.3 (0e150ed, 2023-07-21)" sodipodi:docname="cohere.svg"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview id="namedview1" pagecolor="#ffffff" bordercolor="#000000" borderopacity="0.25" inkscape:showpageshadow="2" inkscape:pageopacity="0.0" inkscape:pagecheckerboard="0" inkscape:deskcolor="#d1d1d1" inkscape:document-units="mm" inkscape:clip-to-page="false" inkscape:zoom="0.69294747" inkscape:cx="67.826209" inkscape:cy="74.320208" inkscape:window-width="1280" inkscape:window-height="688" inkscape:window-x="0" inkscape:window-y="25" inkscape:window-maximized="1" inkscape:current-layer="svg1" />
<defs id="defs1" />
<path id="path7" fill="currentColor" d="m 88.320761,61.142067 c -5.517973,0.07781 -11.05887,-0.197869 -16.558458,0.321489 -6.843243,0.616907 -12.325958,7.018579 -12.29857,13.807832 -0.139102,5.883715 3.981307,11.431418 9.578012,13.180923 3.171819,1.100505 6.625578,1.228214 9.855341,0.291715 3.455286,-0.847586 6.634981,-2.530123 9.969836,-3.746213 4.659947,-1.981154 9.49864,-3.782982 13.612498,-6.795254 3.80146,-2.664209 4.45489,-8.316688 2.00772,-12.1054 -1.74871,-3.034851 -5.172793,-4.896444 -8.663697,-4.741041 -2.49833,-0.140901 -5.000698,-0.196421 -7.502682,-0.214051 z m 7.533907,25.636161 c -3.334456,0.15056 -6.379399,1.79356 -9.409724,3.054098 -2.379329,1.032102 -4.911953,2.154839 -6.246333,4.528375 -2.118159,3.080424 -2.02565,7.404239 0.309716,10.346199 1.877703,2.72985 5.192756,4.03199 8.428778,3.95319 3.087361,0.0764 6.223907,0.19023 9.275119,-0.34329 5.816976,-1.32118 9.855546,-7.83031 8.101436,-13.600351 -1.30234,-4.509858 -5.762,-7.905229 -10.458992,-7.938221 z m -28.342456,4.770768 c -4.357593,-0.129828 -8.148265,3.780554 -8.168711,8.09095 -0.296313,4.101314 2.711752,8.289544 6.873869,8.869074 4.230007,0.80322 8.929483,-2.66416 9.017046,-7.07348 0.213405,-2.445397 0.09191,-5.152074 -1.705492,-7.039611 -1.484313,-1.763448 -3.717801,-2.798154 -6.016712,-2.846933 z" transform="translate(-59.323375,-61.136763)" />
</svg>

After

Width:  |  Height:  |  Size: 2.4 KiB

View File

@ -458,7 +458,7 @@
</span>
</div>
</div>
<div class="range-block" data-source="openai,claude,windowai,openrouter,ai21,scale,makersuite,mistralai,custom">
<div class="range-block" data-source="openai,claude,windowai,openrouter,ai21,scale,makersuite,mistralai,custom,cohere">
<div class="range-block-title" data-i18n="Temperature">
Temperature
</div>
@ -471,7 +471,7 @@
</div>
</div>
</div>
<div data-newbie-hidden class="range-block" data-source="openai,openrouter,ai21,custom">
<div data-newbie-hidden class="range-block" data-source="openai,openrouter,ai21,custom,cohere">
<div class="range-block-title" data-i18n="Frequency Penalty">
Frequency Penalty
</div>
@ -484,7 +484,7 @@
</div>
</div>
</div>
<div data-newbie-hidden class="range-block" data-source="openai,openrouter,ai21,custom">
<div data-newbie-hidden class="range-block" data-source="openai,openrouter,ai21,custom,cohere">
<div class="range-block-title" data-i18n="Presence Penalty">
Presence Penalty
</div>
@ -510,20 +510,20 @@
</div>
</div>
</div>
<div data-newbie-hidden class="range-block" data-source="claude,openrouter,ai21,makersuite">
<div data-newbie-hidden class="range-block" data-source="claude,openrouter,ai21,makersuite,cohere">
<div class="range-block-title" data-i18n="Top K">
Top K
</div>
<div class="range-block-range-and-counter">
<div class="range-block-range">
<input type="range" id="top_k_openai" name="volume" min="0" max="200" step="1">
<input type="range" id="top_k_openai" name="volume" min="0" max="500" step="1">
</div>
<div class="range-block-counter">
<input type="number" min="0" max="200" step="1" data-for="top_k_openai" id="top_k_counter_openai">
</div>
</div>
</div>
<div data-newbie-hidden class="range-block" data-source="openai,claude,openrouter,ai21,scale,makersuite,mistralai,custom">
<div data-newbie-hidden class="range-block" data-source="openai,claude,openrouter,ai21,scale,makersuite,mistralai,custom,cohere">
<div class="range-block-title" data-i18n="Top-p">
Top P
</div>
@ -759,7 +759,7 @@
</div>
</div>
</div>
<div data-newbie-hidden class="range-block" data-source="openai,openrouter,mistralai,custom">
<div data-newbie-hidden class="range-block" data-source="openai,openrouter,mistralai,custom,cohere">
<div class="range-block-title justifyLeft" data-i18n="Seed">
Seed
</div>
@ -2259,15 +2259,20 @@
Chat Completion Source
</h4>
<select id="chat_completion_source">
<option value="openai">OpenAI</option>
<option value="windowai">Window AI</option>
<option value="openrouter">OpenRouter</option>
<option value="claude">Claude</option>
<option value="scale">Scale</option>
<option value="ai21">AI21</option>
<option value="makersuite">Google MakerSuite</option>
<option value="mistralai">MistralAI</option>
<option value="custom">Custom (OpenAI-compatible)</option>
<optgroup>
<option value="openai">OpenAI</option>
<option value="custom">Custom (OpenAI-compatible)</option>
</optgroup>
<optgroup>
<option value="ai21">AI21</option>
<option value="claude">Claude</option>
<option value="cohere">Cohere</option>
<option value="makersuite">Google MakerSuite</option>
<option value="mistralai">MistralAI</option>
<option value="openrouter">OpenRouter</option>
<option value="scale">Scale</option>
<option value="windowai">Window AI</option>
</optgroup>
</select>
<div data-newbie-hidden class="inline-drawer wide100p" data-source="openai,claude,mistralai">
<div class="inline-drawer-toggle inline-drawer-header">
@ -2659,6 +2664,30 @@
</select>
</div>
</form>
<form id="cohere_form" data-source="cohere" action="javascript:void(null);" method="post" enctype="multipart/form-data">
<h4 data-i18n="Cohere API Key">Cohere API Key</h4>
<div class="flex-container">
<input id="api_key_cohere" name="api_key_cohere" class="text_pole flex1" maxlength="500" value="" type="text" autocomplete="off">
<div title="Clear your API key" data-i18n="[title]Clear your API key" class="menu_button fa-solid fa-circle-xmark clear-api-key" data-key="api_key_cohere"></div>
</div>
<div data-for="api_key_cohere" class="neutral_warning">
For privacy reasons, your API key will be hidden after you reload the page.
</div>
<div>
<h4 data-i18n="Cohere Model">Cohere Model</h4>
<select id="model_cohere_select">
<optgroup label="Stable">
<option value="command-light">command-light</option>
<option value="command">command</option>
<option value="command-r">command-r</option>
</optgroup>
<optgroup label="Nightly">
<option value="command-light-nightly">command-light-nightly</option>
<option value="command-nightly">command-nightly</option>
</optgroup>
</select>
</div>
</form>
<form id="custom_form" data-source="custom">
<h4 data-i18n="Custom Endpoint (Base URL)">Custom Endpoint (Base URL)</h4>
<div class="flex-container">

View File

@ -4836,7 +4836,7 @@ function extractMessageFromData(data) {
case 'novel':
return data.output;
case 'openai':
return data?.choices?.[0]?.message?.content ?? data?.choices?.[0]?.text ?? '';
return data?.choices?.[0]?.message?.content ?? data?.choices?.[0]?.text ?? data?.text ?? '';
default:
return '';
}
@ -8187,6 +8187,11 @@ const CONNECT_API_MAP = {
button: '#api_button_openai',
source: chat_completion_sources.CUSTOM,
},
'cohere': {
selected: 'cohere',
button: '#api_button_openai',
source: chat_completion_sources.COHERE,
},
'infermaticai': {
selected: 'textgenerationwebui',
button: '#api_button_textgenerationwebui',

View File

@ -350,6 +350,7 @@ function RA_autoconnect(PrevApi) {
|| (secret_state[SECRET_KEYS.AI21] && oai_settings.chat_completion_source == chat_completion_sources.AI21)
|| (secret_state[SECRET_KEYS.MAKERSUITE] && oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE)
|| (secret_state[SECRET_KEYS.MISTRALAI] && oai_settings.chat_completion_source == chat_completion_sources.MISTRALAI)
|| (secret_state[SECRET_KEYS.COHERE] && oai_settings.chat_completion_source == chat_completion_sources.COHERE)
|| (isValidUrl(oai_settings.custom_url) && oai_settings.chat_completion_source == chat_completion_sources.CUSTOM)
) {
$('#api_button_openai').trigger('click');

View File

@ -171,6 +171,7 @@ export const chat_completion_sources = {
MAKERSUITE: 'makersuite',
MISTRALAI: 'mistralai',
CUSTOM: 'custom',
COHERE: 'cohere',
};
const character_names_behavior = {
@ -230,6 +231,7 @@ const default_settings = {
google_model: 'gemini-pro',
ai21_model: 'j2-ultra',
mistralai_model: 'mistral-medium-latest',
cohere_model: 'command-r',
custom_model: '',
custom_url: '',
custom_include_body: '',
@ -298,6 +300,7 @@ const oai_settings = {
google_model: 'gemini-pro',
ai21_model: 'j2-ultra',
mistralai_model: 'mistral-medium-latest',
cohere_model: 'command-r',
custom_model: '',
custom_url: '',
custom_include_body: '',
@ -1384,6 +1387,8 @@ function getChatCompletionModel() {
return oai_settings.mistralai_model;
case chat_completion_sources.CUSTOM:
return oai_settings.custom_model;
case chat_completion_sources.COHERE:
return oai_settings.cohere_model;
default:
throw new Error(`Unknown chat completion source: ${oai_settings.chat_completion_source}`);
}
@ -1603,6 +1608,7 @@ async function sendOpenAIRequest(type, messages, signal) {
const isOAI = oai_settings.chat_completion_source == chat_completion_sources.OPENAI;
const isMistral = oai_settings.chat_completion_source == chat_completion_sources.MISTRALAI;
const isCustom = oai_settings.chat_completion_source == chat_completion_sources.CUSTOM;
const isCohere = oai_settings.chat_completion_source == chat_completion_sources.COHERE;
const isTextCompletion = (isOAI && textCompletionModels.includes(oai_settings.openai_model)) || (isOpenRouter && oai_settings.openrouter_force_instruct && power_user.instruct.enabled);
const isQuiet = type === 'quiet';
const isImpersonate = type === 'impersonate';
@ -1737,7 +1743,17 @@ async function sendOpenAIRequest(type, messages, signal) {
generate_data['custom_include_headers'] = oai_settings.custom_include_headers;
}
if ((isOAI || isOpenRouter || isMistral || isCustom) && oai_settings.seed >= 0) {
if (isCohere) {
// Clamp to 0.01 -> 0.99
generate_data['top_p'] = Math.min(Math.max(Number(oai_settings.top_p_openai), 0.01), 0.99);
generate_data['top_k'] = Number(oai_settings.top_k_openai);
// Clamp to 0 -> 1
generate_data['frequency_penalty'] = Math.min(Math.max(Number(oai_settings.freq_pen_openai), 0), 1);
generate_data['presence_penalty'] = Math.min(Math.max(Number(oai_settings.pres_pen_openai), 0), 1);
generate_data['stop'] = getCustomStoppingStrings(5);
}
if ((isOAI || isOpenRouter || isMistral || isCustom || isCohere) && oai_settings.seed >= 0) {
generate_data['seed'] = oai_settings.seed;
}
@ -2597,6 +2613,7 @@ function loadOpenAISettings(data, settings) {
oai_settings.openrouter_force_instruct = settings.openrouter_force_instruct ?? default_settings.openrouter_force_instruct;
oai_settings.ai21_model = settings.ai21_model ?? default_settings.ai21_model;
oai_settings.mistralai_model = settings.mistralai_model ?? default_settings.mistralai_model;
oai_settings.cohere_model = settings.cohere_model ?? default_settings.cohere_model;
oai_settings.custom_model = settings.custom_model ?? default_settings.custom_model;
oai_settings.custom_url = settings.custom_url ?? default_settings.custom_url;
oai_settings.custom_include_body = settings.custom_include_body ?? default_settings.custom_include_body;
@ -2657,6 +2674,8 @@ function loadOpenAISettings(data, settings) {
$(`#model_ai21_select option[value="${oai_settings.ai21_model}"`).attr('selected', true);
$('#model_mistralai_select').val(oai_settings.mistralai_model);
$(`#model_mistralai_select option[value="${oai_settings.mistralai_model}"`).attr('selected', true);
$('#model_cohere_select').val(oai_settings.cohere_model);
$(`#model_cohere_select option[value="${oai_settings.cohere_model}"`).attr('selected', true);
$('#custom_model_id').val(oai_settings.custom_model);
$('#custom_api_url_text').val(oai_settings.custom_url);
$('#openai_max_context').val(oai_settings.openai_max_context);
@ -2893,6 +2912,7 @@ async function saveOpenAIPreset(name, settings, triggerUi = true) {
openrouter_sort_models: settings.openrouter_sort_models,
ai21_model: settings.ai21_model,
mistralai_model: settings.mistralai_model,
cohere_model: settings.cohere_model,
custom_model: settings.custom_model,
custom_url: settings.custom_url,
custom_include_body: settings.custom_include_body,
@ -3281,6 +3301,7 @@ function onSettingsPresetChange() {
openrouter_sort_models: ['#openrouter_sort_models', 'openrouter_sort_models', false],
ai21_model: ['#model_ai21_select', 'ai21_model', false],
mistralai_model: ['#model_mistralai_select', 'mistralai_model', false],
cohere_model: ['#model_cohere_select', 'cohere_model', false],
custom_model: ['#custom_model_id', 'custom_model', false],
custom_url: ['#custom_api_url_text', 'custom_url', false],
custom_include_body: ['#custom_include_body', 'custom_include_body', false],
@ -3496,6 +3517,11 @@ async function onModelChange() {
$('#model_mistralai_select').val(oai_settings.mistralai_model);
}
if ($(this).is('#model_cohere_select')) {
console.log('Cohere model changed to', value);
oai_settings.cohere_model = value;
}
if (value && $(this).is('#model_custom_select')) {
console.log('Custom model changed to', value);
oai_settings.custom_model = value;
@ -3619,6 +3645,26 @@ async function onModelChange() {
$('#temp_openai').attr('max', claude_max_temp).val(oai_settings.temp_openai).trigger('input');
}
if (oai_settings.chat_completion_source === chat_completion_sources.COHERE) {
if (oai_settings.max_context_unlocked) {
$('#openai_max_context').attr('max', unlocked_max);
}
else if (['command-light', 'command'].includes(oai_settings.cohere_model)) {
$('#openai_max_context').attr('max', max_4k);
}
else if (['command-light-nightly', 'command-nightly'].includes(oai_settings.cohere_model)) {
$('#openai_max_context').attr('max', max_8k);
}
else if (['command-r'].includes(oai_settings.cohere_model)) {
$('#openai_max_context').attr('max', max_128k);
}
else {
$('#openai_max_context').attr('max', max_4k);
}
oai_settings.openai_max_context = Math.min(Number($('#openai_max_context').attr('max')), oai_settings.openai_max_context);
$('#openai_max_context').val(oai_settings.openai_max_context).trigger('input');
}
if (oai_settings.chat_completion_source == chat_completion_sources.AI21) {
if (oai_settings.max_context_unlocked) {
$('#openai_max_context').attr('max', unlocked_max);
@ -3812,6 +3858,19 @@ async function onConnectButtonClick(e) {
}
}
if (oai_settings.chat_completion_source == chat_completion_sources.COHERE) {
const api_key_cohere = String($('#api_key_cohere').val()).trim();
if (api_key_cohere.length) {
await writeSecret(SECRET_KEYS.COHERE, api_key_cohere);
}
if (!secret_state[SECRET_KEYS.COHERE]) {
console.log('No secret key saved for Cohere');
return;
}
}
startStatusLoading();
saveSettingsDebounced();
await getStatusOpen();
@ -3847,6 +3906,9 @@ function toggleChatCompletionForms() {
else if (oai_settings.chat_completion_source == chat_completion_sources.MISTRALAI) {
$('#model_mistralai_select').trigger('change');
}
else if (oai_settings.chat_completion_source == chat_completion_sources.COHERE) {
$('#model_cohere_select').trigger('change');
}
else if (oai_settings.chat_completion_source == chat_completion_sources.CUSTOM) {
$('#model_custom_select').trigger('change');
}
@ -4499,6 +4561,7 @@ $(document).ready(async function () {
$('#openrouter_sort_models').on('change', onOpenrouterModelSortChange);
$('#model_ai21_select').on('change', onModelChange);
$('#model_mistralai_select').on('change', onModelChange);
$('#model_cohere_select').on('change', onModelChange);
$('#model_custom_select').on('change', onModelChange);
$('#settings_preset_openai').on('change', onSettingsPresetChange);
$('#new_oai_preset').on('click', onNewPresetClick);

View File

@ -23,6 +23,7 @@ export const SECRET_KEYS = {
NOMICAI: 'api_key_nomicai',
KOBOLDCPP: 'api_key_koboldcpp',
LLAMACPP: 'api_key_llamacpp',
COHERE: 'api_key_cohere',
};
const INPUT_MAP = {
@ -47,6 +48,7 @@ const INPUT_MAP = {
[SECRET_KEYS.NOMICAI]: '#api_key_nomicai',
[SECRET_KEYS.KOBOLDCPP]: '#api_key_koboldcpp',
[SECRET_KEYS.LLAMACPP]: '#api_key_llamacpp',
[SECRET_KEYS.COHERE]: '#api_key_cohere',
};
async function clearSecret() {

View File

@ -1660,6 +1660,7 @@ function modelCallback(_, model) {
{ id: 'model_google_select', api: 'openai', type: chat_completion_sources.MAKERSUITE },
{ id: 'model_mistralai_select', api: 'openai', type: chat_completion_sources.MISTRALAI },
{ id: 'model_custom_select', api: 'openai', type: chat_completion_sources.CUSTOM },
{ id: 'model_cohere_select', api: 'openai', type: chat_completion_sources.COHERE },
{ id: 'model_novel_select', api: 'novel', type: null },
{ id: 'horde_model', api: 'koboldhorde', type: null },
];

View File

@ -162,6 +162,7 @@ const CHAT_COMPLETION_SOURCES = {
MAKERSUITE: 'makersuite',
MISTRALAI: 'mistralai',
CUSTOM: 'custom',
COHERE: 'cohere',
};
const UPLOADS_PATH = './uploads';

View File

@ -1,10 +1,11 @@
const express = require('express');
const fetch = require('node-fetch').default;
const Readable = require('stream').Readable;
const { jsonParser } = require('../../express-common');
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt } = require('../../prompt-converters');
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages } = require('../../prompt-converters');
const { readSecret, SECRET_KEYS } = require('../secrets');
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
@ -12,6 +13,61 @@ const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sente
const API_OPENAI = 'https://api.openai.com/v1';
const API_CLAUDE = 'https://api.anthropic.com/v1';
const API_MISTRAL = 'https://api.mistral.ai/v1';
const API_COHERE = 'https://api.cohere.ai/v1';
/**
 * Converts Cohere's newline-delimited JSON stream into a proper SSE stream
 * that the frontend's streaming processor understands.
 * @param {import('node-fetch').Response} jsonStream JSON stream
 * @param {import('express').Request} request Express request
 * @param {import('express').Response} response Express response
 * @returns {Promise<any>} Nothing valuable
 */
async function parseCohereStream(jsonStream, request, response) {
    try {
        let partialData = '';

        // Parses one raw line: emits an SSE chunk for 'text-generation' events,
        // silently skips blank lines, incomplete JSON, and other event types.
        const emitTextEvent = (raw) => {
            const line = raw.trim();
            if (!line) return;
            let json;
            try {
                json = JSON.parse(line);
            } catch {
                return;
            }
            if (json.event_type === 'text-generation') {
                const chunk = { choices: [{ text: json.text || '' }] };
                response.write(`data: ${JSON.stringify(chunk)}\n\n`);
            }
        };

        jsonStream.body.on('data', (data) => {
            partialData += data.toString();
            // Cohere sends newline-delimited JSON. Process every complete line and
            // keep the trailing (possibly incomplete) fragment for the next chunk.
            // Parsing the whole buffer as one JSON value would stall forever once
            // two events arrived concatenated in a single chunk.
            const lines = partialData.split('\n');
            partialData = lines.pop();
            lines.forEach(emitTextEvent);
        });
        request.socket.on('close', function () {
            if (jsonStream.body instanceof Readable) jsonStream.body.destroy();
            response.end();
        });
        jsonStream.body.on('end', () => {
            // Flush a final event that arrived without a trailing newline.
            emitTextEvent(partialData);
            console.log('Streaming request finished');
            response.write('data: [DONE]\n\n');
            response.end();
        });
    } catch (error) {
        console.log('Error forwarding streaming response:', error);
        if (!response.headersSent) {
            return response.status(500).send({ error: true });
        } else {
            return response.end();
        }
    }
}
/**
* Sends a request to Claude API.
* @param {express.Request} request Express request
@ -460,6 +516,85 @@ async function sendMistralAIRequest(request, response) {
}
}
/**
 * Sends a chat completion request to the Cohere API.
 * Streams the response as SSE when `request.body.stream` is set, otherwise
 * forwards the raw JSON reply to the client.
 * @param {import('express').Request} request Express request
 * @param {import('express').Response} response Express response
 */
async function sendCohereRequest(request, response) {
    const apiKey = readSecret(SECRET_KEYS.COHERE);
    const controller = new AbortController();
    request.socket.removeAllListeners('close');
    request.socket.on('close', function () {
        // Abort the upstream fetch if the client disconnects.
        controller.abort();
    });

    if (!apiKey) {
        console.log('Cohere API key is missing.');
        return response.status(400).send({ error: true });
    }

    try {
        const convertedHistory = convertCohereMessages(request.body.messages);

        // https://docs.cohere.com/reference/chat
        const requestBody = {
            stream: Boolean(request.body.stream),
            model: request.body.model,
            message: convertedHistory.userPrompt,
            preamble: convertedHistory.systemPrompt,
            chat_history: convertedHistory.chatHistory,
            temperature: request.body.temperature,
            max_tokens: request.body.max_tokens,
            k: request.body.top_k,
            p: request.body.top_p,
            seed: request.body.seed,
            stop_sequences: request.body.stop,
            frequency_penalty: request.body.frequency_penalty,
            presence_penalty: request.body.presence_penalty,
            prompt_truncation: 'AUTO_PRESERVE_ORDER',
            connectors: [], // TODO
            documents: [],
            tools: [],
            tool_results: [],
            search_queries_only: false,
        };

        console.log('Cohere request:', requestBody);

        const config = {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
                'Authorization': 'Bearer ' + apiKey,
            },
            body: JSON.stringify(requestBody),
            signal: controller.signal,
            timeout: 0,
        };

        const apiUrl = API_COHERE + '/chat';

        if (request.body.stream) {
            const stream = await fetch(apiUrl, config);
            // Await so errors raised while piping the stream are caught below
            // instead of leaving a floating (unhandled) promise.
            await parseCohereStream(stream, request, response);
        } else {
            const generateResponse = await fetch(apiUrl, config);
            if (!generateResponse.ok) {
                console.log(`Cohere API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
                // a 401 unauthorized response breaks the frontend auth, so return a 500 instead. prob a better way of dealing with this.
                // 401s are already handled by the streaming processor and dont pop up an error toast, that should probably be fixed too.
                return response.status(generateResponse.status === 401 ? 500 : generateResponse.status).send({ error: true });
            }
            const generateResponseJson = await generateResponse.json();
            console.log('Cohere response:', generateResponseJson);
            return response.send(generateResponseJson);
        }
    } catch (error) {
        console.log('Error communicating with Cohere API: ', error);
        if (!response.headersSent) {
            response.send({ error: true });
        } else {
            response.end();
        }
    }
}
const router = express.Router();
router.post('/status', jsonParser, async function (request, response_getstatus_openai) {
@ -487,6 +622,10 @@ router.post('/status', jsonParser, async function (request, response_getstatus_o
api_key_openai = readSecret(SECRET_KEYS.CUSTOM);
headers = {};
mergeObjectWithYaml(headers, request.body.custom_include_headers);
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.COHERE) {
api_url = API_COHERE;
api_key_openai = readSecret(SECRET_KEYS.COHERE);
headers = {};
} else {
console.log('This chat completion source is not supported yet.');
return response_getstatus_openai.status(400).send({ error: true });
@ -510,6 +649,10 @@ router.post('/status', jsonParser, async function (request, response_getstatus_o
const data = await response.json();
response_getstatus_openai.send(data);
if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.COHERE && Array.isArray(data?.models)) {
data.data = data.models.map(model => ({ id: model.name, ...model }));
}
if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENROUTER && Array.isArray(data?.data)) {
let models = [];
@ -635,6 +778,7 @@ router.post('/generate', jsonParser, function (request, response) {
case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response);
case CHAT_COMPLETION_SOURCES.MAKERSUITE: return sendMakerSuiteRequest(request, response);
case CHAT_COMPLETION_SOURCES.MISTRALAI: return sendMistralAIRequest(request, response);
case CHAT_COMPLETION_SOURCES.COHERE: return sendCohereRequest(request, response);
}
let apiUrl;

View File

@ -35,6 +35,7 @@ const SECRET_KEYS = {
NOMICAI: 'api_key_nomicai',
KOBOLDCPP: 'api_key_koboldcpp',
LLAMACPP: 'api_key_llamacpp',
COHERE: 'api_key_cohere',
};
// These are the keys that are safe to expose, even if allowKeysExposure is false

8
src/polyfill.js Normal file
View File

@ -0,0 +1,8 @@
// Polyfill for Array.prototype.findLastIndex (ES2023, native in Node >= 18).
// Installed via defineProperty so the method stays non-enumerable, matching the
// native implementation — a plain prototype assignment would make 'findLastIndex'
// show up in every `for...in` loop over an array.
if (!Array.prototype.findLastIndex) {
    Object.defineProperty(Array.prototype, 'findLastIndex', {
        value: function (callback, thisArg) {
            // Walk backwards and return the first (i.e. last) index the predicate accepts.
            for (let i = this.length - 1; i >= 0; i--) {
                if (callback.call(thisArg, this[i], i, this)) return i;
            }
            return -1;
        },
        writable: true,
        configurable: true,
        enumerable: false,
    });
}

View File

@ -1,3 +1,5 @@
require('./polyfill.js');
/**
* Convert a prompt from the ChatML objects to the format used by Claude.
* @param {object[]} messages Array of messages
@ -188,6 +190,64 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
return { messages: mergedMessages, systemPrompt: systemPrompt.trim() };
}
/**
 * Convert a prompt from the ChatML objects to the format used by Cohere.
 * Leading system messages are folded into the preamble; the trailing run of
 * messages (from the last user/assistant message onward) becomes the user
 * prompt; everything in between becomes chat history.
 * Does not mutate the input array or its message objects.
 * @param {object[]} messages Array of messages
 * @param {string} charName Character name
 * @param {string} userName User name
 * @returns {{systemPrompt: string, chatHistory: object[], userPrompt: string}} Prompt for Cohere
 */
function convertCohereMessages(messages, charName = '', userName = '') {
    const roleMap = {
        'system': 'SYSTEM',
        'user': 'USER',
        'assistant': 'CHATBOT',
    };
    const placeholder = '[Start a new chat]';
    let systemPrompt = '';

    // Shallow-copy each message so the caller's array is never mutated.
    const chat = messages.map(msg => ({ ...msg }));

    // Collect all the system messages up until the first instance of a non-system message,
    // and then remove them from the working array.
    let i;
    for (i = 0; i < chat.length; i++) {
        if (chat[i].role !== 'system') {
            break;
        }
        // Append example names if not already done by the frontend (e.g. for group chats).
        if (userName && chat[i].name === 'example_user') {
            if (!chat[i].content.startsWith(`${userName}: `)) {
                chat[i].content = `${userName}: ${chat[i].content}`;
            }
        }
        if (charName && chat[i].name === 'example_assistant') {
            if (!chat[i].content.startsWith(`${charName}: `)) {
                chat[i].content = `${charName}: ${chat[i].content}`;
            }
        }
        systemPrompt += `${chat[i].content}\n\n`;
    }
    chat.splice(0, i);

    // Cohere requires a non-empty message; fall back to a placeholder turn.
    if (chat.length === 0) {
        chat.push({
            role: 'user',
            content: placeholder,
        });
    }

    const lastNonSystemMessageIndex = chat.findLastIndex(msg => msg.role === 'user' || msg.role === 'assistant');

    const userPrompt = chat.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || placeholder;
    const chatHistory = chat.slice(0, lastNonSystemMessageIndex).map(msg => {
        return {
            role: roleMap[msg.role] || 'USER',
            message: msg.content,
        };
    });

    return { systemPrompt: systemPrompt.trim(), chatHistory, userPrompt };
}
/**
* Convert a prompt from the ChatML objects to the format used by Google MakerSuite models.
* @param {object[]} messages Array of messages
@ -300,4 +360,5 @@ module.exports = {
convertClaudeMessages,
convertGooglePrompt,
convertTextCompletionPrompt,
convertCohereMessages,
};