#1994 Add Cohere as a Chat Completion source

This commit is contained in:
Cohee
2024-04-02 00:20:17 +03:00
parent 9c6d8e6895
commit 9838ba8044
12 changed files with 347 additions and 19 deletions

View File

@ -1,3 +1,5 @@
require('./polyfill.js');
/**
* Convert a prompt from the ChatML objects to the format used by Claude.
* @param {object[]} messages Array of messages
@ -188,6 +190,64 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
return { messages: mergedMessages, systemPrompt: systemPrompt.trim() };
}
/**
* Convert a prompt from the ChatML objects to the format used by Cohere.
* @param {object[]} messages Array of messages
* @param {string} charName Character name
* @param {string} userName User name
* @returns {{systemPrompt: string, chatHistory: object[], userPrompt: string}} Prompt for Cohere
*/
function convertCohereMessages(messages, charName = '', userName = '') {
const roleMap = {
'system': 'SYSTEM',
'user': 'USER',
'assistant': 'CHATBOT',
};
const placeholder = '[Start a new chat]';
let systemPrompt = '';
// Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
let i;
for (i = 0; i < messages.length; i++) {
if (messages[i].role !== 'system') {
break;
}
// Append example names if not already done by the frontend (e.g. for group chats).
if (userName && messages[i].name === 'example_user') {
if (!messages[i].content.startsWith(`${userName}: `)) {
messages[i].content = `${userName}: ${messages[i].content}`;
}
}
if (charName && messages[i].name === 'example_assistant') {
if (!messages[i].content.startsWith(`${charName}: `)) {
messages[i].content = `${charName}: ${messages[i].content}`;
}
}
systemPrompt += `${messages[i].content}\n\n`;
}
messages.splice(0, i);
if (messages.length === 0) {
messages.unshift({
role: 'user',
content: placeholder,
});
}
const lastNonSystemMessageIndex = messages.findLastIndex(msg => msg.role === 'user' || msg.role === 'assistant');
const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || placeholder;
const chatHistory = messages.slice(0, lastNonSystemMessageIndex).map(msg => {
return {
role: roleMap[msg.role] || 'USER',
message: msg.content,
};
});
return { systemPrompt: systemPrompt.trim(), chatHistory, userPrompt };
}
/**
* Convert a prompt from the ChatML objects to the format used by Google MakerSuite models.
* @param {object[]} messages Array of messages
@ -300,4 +360,5 @@ module.exports = {
convertClaudeMessages,
convertGooglePrompt,
convertTextCompletionPrompt,
convertCohereMessages,
};