Merge branch 'staging' into feat/AdditionalLogins

This commit is contained in:
Cohee
2024-10-09 01:34:29 +03:00
32 changed files with 1594 additions and 516 deletions

View File

@@ -4,7 +4,7 @@ const fetch = require('node-fetch').default;
const { jsonParser } = require('../../express-common');
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools, convertAI21Messages } = require('../../prompt-converters');
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertAI21Messages, mergeMessages } = require('../../prompt-converters');
const CohereStream = require('../../cohere-stream');
const { readSecret, SECRET_KEYS } = require('../secrets');
@@ -31,8 +31,11 @@ const API_AI21 = 'https://api.ai21.com/studio/v1';
*/
function postProcessPrompt(messages, type, charName, userName) {
switch (type) {
case 'merge':
case 'claude':
return convertClaudeMessages(messages, '', false, '', charName, userName).messages;
return mergeMessages(messages, charName, userName, false);
case 'strict':
return mergeMessages(messages, charName, userName, true);
default:
return messages;
}
@@ -84,7 +87,7 @@ async function sendClaudeRequest(request, response) {
const apiUrl = new URL(request.body.reverse_proxy || API_CLAUDE).toString();
const apiKey = request.body.reverse_proxy ? request.body.proxy_password : readSecret(request.user.directories, SECRET_KEYS.CLAUDE);
const divider = '-'.repeat(process.stdout.columns);
const enableSystemPromptCache = getConfigValue('claude.enableSystemPromptCache', false);
const enableSystemPromptCache = getConfigValue('claude.enableSystemPromptCache', false) && request.body.model.startsWith('claude-3');
if (!apiKey) {
console.log(color.red(`Claude API key is missing.\n${divider}`));
@@ -98,8 +101,9 @@ async function sendClaudeRequest(request, response) {
controller.abort();
});
const additionalHeaders = {};
const useTools = request.body.model.startsWith('claude-3') && Array.isArray(request.body.tools) && request.body.tools.length > 0;
const useSystemPrompt = (request.body.model.startsWith('claude-2') || request.body.model.startsWith('claude-3')) && request.body.claude_use_sysprompt;
const convertedPrompt = convertClaudeMessages(request.body.messages, request.body.assistant_prefill, useSystemPrompt, request.body.human_sysprompt_message, request.body.char_name, request.body.user_name);
const convertedPrompt = convertClaudeMessages(request.body.messages, request.body.assistant_prefill, useSystemPrompt, useTools, request.body.human_sysprompt_message, request.body.char_name, request.body.user_name);
// Add custom stop sequences
const stopSequences = [];
if (Array.isArray(request.body.stop)) {
@@ -107,7 +111,7 @@ async function sendClaudeRequest(request, response) {
}
const requestBody = {
/** @type {any} */ system: '',
/** @type {any} */ system: [],
messages: convertedPrompt.messages,
model: request.body.model,
max_tokens: request.body.max_tokens,
@@ -118,23 +122,29 @@ async function sendClaudeRequest(request, response) {
stream: request.body.stream,
};
if (useSystemPrompt) {
requestBody.system = enableSystemPromptCache
? [{ type: 'text', text: convertedPrompt.systemPrompt, cache_control: { type: 'ephemeral' } }]
: convertedPrompt.systemPrompt;
if (enableSystemPromptCache && Array.isArray(convertedPrompt.systemPrompt) && convertedPrompt.systemPrompt.length) {
convertedPrompt.systemPrompt[convertedPrompt.systemPrompt.length - 1]['cache_control'] = { type: 'ephemeral' };
}
requestBody.system = convertedPrompt.systemPrompt;
} else {
delete requestBody.system;
}
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
if (useTools) {
// Claude doesn't do prefills on function calls, and doesn't allow empty messages
if (convertedPrompt.messages.length && convertedPrompt.messages[convertedPrompt.messages.length - 1].role === 'assistant') {
convertedPrompt.messages.push({ role: 'user', content: '.' });
}
additionalHeaders['anthropic-beta'] = 'tools-2024-05-16';
requestBody.tool_choice = { type: request.body.tool_choice === 'required' ? 'any' : 'auto' };
requestBody.tool_choice = { type: request.body.tool_choice };
requestBody.tools = request.body.tools
.filter(tool => tool.type === 'function')
.map(tool => tool.function)
.map(fn => ({ name: fn.name, description: fn.description, input_schema: fn.parameters }));
if (enableSystemPromptCache && requestBody.tools.length) {
requestBody.tools[requestBody.tools.length - 1]['cache_control'] = { type: 'ephemeral' };
}
}
if (enableSystemPromptCache) {
additionalHeaders['anthropic-beta'] = 'prompt-caching-2024-07-31';
@@ -483,7 +493,7 @@ async function sendMistralAIRequest(request, response) {
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
requestBody['tools'] = request.body.tools;
requestBody['tool_choice'] = request.body.tool_choice === 'required' ? 'any' : 'auto';
requestBody['tool_choice'] = request.body.tool_choice;
}
const config = {
@@ -553,12 +563,6 @@ async function sendCohereRequest(request, response) {
});
}
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
tools.push(...convertCohereTools(request.body.tools));
// Can't have both connectors and tools in the same request
connectors.splice(0, connectors.length);
}
// https://docs.cohere.com/reference/chat
const requestBody = {
stream: Boolean(request.body.stream),
@@ -908,24 +912,12 @@ router.post('/generate', jsonParser, function (request, response) {
apiKey = readSecret(request.user.directories, SECRET_KEYS.PERPLEXITY);
headers = {};
bodyParams = {};
request.body.messages = postProcessPrompt(request.body.messages, 'claude', request.body.char_name, request.body.user_name);
request.body.messages = postProcessPrompt(request.body.messages, 'strict', request.body.char_name, request.body.user_name);
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.GROQ) {
apiUrl = API_GROQ;
apiKey = readSecret(request.user.directories, SECRET_KEYS.GROQ);
headers = {};
bodyParams = {};
// 'required' tool choice is not supported by Groq
if (request.body.tool_choice === 'required') {
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
request.body.tool_choice = request.body.tools.length > 1
? 'auto' :
{ type: 'function', function: { name: request.body.tools[0]?.function?.name } };
} else {
request.body.tool_choice = 'none';
}
}
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.ZEROONEAI) {
apiUrl = API_01AI;
apiKey = readSecret(request.user.directories, SECRET_KEYS.ZEROONEAI);
@@ -962,7 +954,7 @@ router.post('/generate', jsonParser, function (request, response) {
controller.abort();
});
if (!isTextCompletion) {
if (!isTextCompletion && Array.isArray(request.body.tools) && request.body.tools.length > 0) {
bodyParams['tools'] = request.body.tools;
bodyParams['tool_choice'] = request.body.tool_choice;
}

View File

@@ -51,7 +51,7 @@ const eratoRepPenWhitelist = [
6, 1, 11, 13, 25, 198, 12, 9, 8, 279, 264, 459, 323, 477, 539, 912, 374, 574, 1051, 1550, 1587, 4536, 5828, 15058,
3287, 3250, 1461, 1077, 813, 11074, 872, 1202, 1436, 7846, 1288, 13434, 1053, 8434, 617, 9167, 1047, 19117, 706,
12775, 649, 4250, 527, 7784, 690, 2834, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 1210, 1359, 608, 220, 596, 956,
3077, 44886, 4265, 3358, 2351, 2846, 311, 389, 315, 304, 520, 505, 430
3077, 44886, 4265, 3358, 2351, 2846, 311, 389, 315, 304, 520, 505, 430,
];
// Ban the dinkus and asterism

View File

@@ -22,6 +22,72 @@ const visitHeaders = {
'Sec-Fetch-User': '?1',
};
/**
* Extract the transcript of a YouTube video
* @param {string} videoPageBody HTML of the video page
* @param {string} lang Language code
* @returns {Promise<string>} Transcript text
*/
async function extractTranscript(videoPageBody, lang) {
const he = require('he');
const RE_XML_TRANSCRIPT = /<text start="([^"]*)" dur="([^"]*)">([^<]*)<\/text>/g;
const splittedHTML = videoPageBody.split('"captions":');
if (splittedHTML.length <= 1) {
if (videoPageBody.includes('class="g-recaptcha"')) {
throw new Error('Too many requests');
}
if (!videoPageBody.includes('"playabilityStatus":')) {
throw new Error('Video is not available');
}
throw new Error('Transcript not available');
}
const captions = (() => {
try {
return JSON.parse(splittedHTML[1].split(',"videoDetails')[0].replace('\n', ''));
} catch (e) {
return undefined;
}
})()?.['playerCaptionsTracklistRenderer'];
if (!captions) {
throw new Error('Transcript disabled');
}
if (!('captionTracks' in captions)) {
throw new Error('Transcript not available');
}
if (lang && !captions.captionTracks.some(track => track.languageCode === lang)) {
throw new Error('Transcript not available in this language');
}
const transcriptURL = (lang ? captions.captionTracks.find(track => track.languageCode === lang) : captions.captionTracks[0]).baseUrl;
const transcriptResponse = await fetch(transcriptURL, {
headers: {
...(lang && { 'Accept-Language': lang }),
'User-Agent': visitHeaders['User-Agent'],
},
});
if (!transcriptResponse.ok) {
throw new Error('Transcript request failed');
}
const transcriptBody = await transcriptResponse.text();
const results = [...transcriptBody.matchAll(RE_XML_TRANSCRIPT)];
const transcript = results.map((result) => ({
text: result[3],
duration: parseFloat(result[2]),
offset: parseFloat(result[1]),
lang: lang ?? captions.captionTracks[0].languageCode,
}));
// The text is double-encoded
const transcriptText = transcript.map((line) => he.decode(he.decode(line.text))).join(' ');
return transcriptText;
}
router.post('/serpapi', jsonParser, async (request, response) => {
try {
const key = readSecret(request.user.directories, SECRET_KEYS.SERPAPI);
@@ -56,10 +122,9 @@ router.post('/serpapi', jsonParser, async (request, response) => {
*/
router.post('/transcript', jsonParser, async (request, response) => {
try {
const he = require('he');
const RE_XML_TRANSCRIPT = /<text start="([^"]*)" dur="([^"]*)">([^<]*)<\/text>/g;
const id = request.body.id;
const lang = request.body.lang;
const json = request.body.json;
if (!id) {
console.log('Id is required for /transcript');
@@ -74,62 +139,18 @@ router.post('/transcript', jsonParser, async (request, response) => {
});
const videoPageBody = await videoPageResponse.text();
const splittedHTML = videoPageBody.split('"captions":');
if (splittedHTML.length <= 1) {
if (videoPageBody.includes('class="g-recaptcha"')) {
throw new Error('Too many requests');
try {
const transcriptText = await extractTranscript(videoPageBody, lang);
return json
? response.json({ transcript: transcriptText, html: videoPageBody })
: response.send(transcriptText);
} catch (error) {
if (json) {
return response.json({ html: videoPageBody, transcript: '' });
}
if (!videoPageBody.includes('"playabilityStatus":')) {
throw new Error('Video is not available');
}
throw new Error('Transcript not available');
throw error;
}
const captions = (() => {
try {
return JSON.parse(splittedHTML[1].split(',"videoDetails')[0].replace('\n', ''));
} catch (e) {
return undefined;
}
})()?.['playerCaptionsTracklistRenderer'];
if (!captions) {
throw new Error('Transcript disabled');
}
if (!('captionTracks' in captions)) {
throw new Error('Transcript not available');
}
if (lang && !captions.captionTracks.some(track => track.languageCode === lang)) {
throw new Error('Transcript not available in this language');
}
const transcriptURL = (lang ? captions.captionTracks.find(track => track.languageCode === lang) : captions.captionTracks[0]).baseUrl;
const transcriptResponse = await fetch(transcriptURL, {
headers: {
...(lang && { 'Accept-Language': lang }),
'User-Agent': visitHeaders['User-Agent'],
},
});
if (!transcriptResponse.ok) {
throw new Error('Transcript request failed');
}
const transcriptBody = await transcriptResponse.text();
const results = [...transcriptBody.matchAll(RE_XML_TRANSCRIPT)];
const transcript = results.map((result) => ({
text: result[3],
duration: parseFloat(result[2]),
offset: parseFloat(result[1]),
lang: lang ?? captions.captionTracks[0].languageCode,
}));
// The text is double-encoded
const transcriptText = transcript.map((line) => he.decode(he.decode(line.text))).join(' ');
return response.send(transcriptText);
} catch (error) {
console.log(error);
return response.sendStatus(500);

View File

@@ -1,5 +1,8 @@
require('./polyfill.js');
const { getConfigValue } = require('./util.js');
const crypto = require('crypto');
const PROMPT_PLACEHOLDER = getConfigValue('promptPlaceholder', 'Let\'s get started.');
/**
* Convert a prompt from the ChatML objects to the format used by Claude.
@@ -19,6 +22,14 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill,
//Prepare messages for claude.
//When 'Exclude Human/Assistant prefixes' checked, setting messages role to the 'system'(last message is exception).
if (messages.length > 0) {
messages.forEach((m) => {
if (!m.content) {
m.content = '';
}
if (m.tool_calls) {
m.content += JSON.stringify(m.tool_calls);
}
});
if (excludePrefixes) {
messages.slice(0, -1).forEach(message => message.role = 'system');
} else {
@@ -80,12 +91,13 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill,
* @param {object[]} messages Array of messages
* @param {string} prefillString User determined prefill string
* @param {boolean} useSysPrompt See if we want to use a system prompt
* @param {boolean} useTools See if we want to use tools
* @param {string} humanMsgFix Add Human message between system prompt and assistant.
* @param {string} charName Character name
* @param {string} userName User name
*/
function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFix, charName = '', userName = '') {
let systemPrompt = '';
function convertClaudeMessages(messages, prefillString, useSysPrompt, useTools, humanMsgFix, charName = '', userName = '') {
let systemPrompt = [];
if (useSysPrompt) {
// Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
let i;
@@ -104,7 +116,7 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
messages[i].content = `${charName}: ${messages[i].content}`;
}
}
systemPrompt += `${messages[i].content}\n\n`;
systemPrompt.push({ type: 'text', text: messages[i].content });
}
messages.splice(0, i);
@@ -114,12 +126,32 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
if (messages.length === 0 || (messages.length > 0 && messages[0].role !== 'user')) {
messages.unshift({
role: 'user',
content: humanMsgFix || '[Start a new chat]',
content: humanMsgFix || PROMPT_PLACEHOLDER,
});
}
}
// Now replace all further messages that have the role 'system' with the role 'user'. (or all if we're not using one)
const parse = (str) => typeof str === 'string' ? JSON.parse(str) : str;
messages.forEach((message) => {
if (message.role === 'assistant' && message.tool_calls) {
message.content = message.tool_calls.map((tc) => ({
type: 'tool_use',
id: tc.id,
name: tc.function.name,
input: parse(tc.function.arguments),
}));
}
if (message.role === 'tool') {
message.role = 'user';
message.content = [{
type: 'tool_result',
tool_use_id: message.tool_call_id,
content: message.content,
}];
}
if (message.role === 'system') {
if (userName && message.name === 'example_user') {
message.content = `${userName}: ${message.content}`;
@@ -128,14 +160,81 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
message.content = `${charName}: ${message.content}`;
}
message.role = 'user';
// Delete name here so it doesn't get added later
delete message.name;
}
// Convert everything to an array of it would be easier to work with
if (typeof message.content === 'string') {
// Take care of name properties since claude messages don't support them
if (message.name) {
message.content = `${message.name}: ${message.content}`;
}
message.content = [{ type: 'text', text: message.content }];
} else if (Array.isArray(message.content)) {
message.content = message.content.map((content) => {
if (content.type === 'image_url') {
const imageEntry = content?.image_url;
const imageData = imageEntry?.url;
const mimeType = imageData?.split(';')?.[0].split(':')?.[1];
const base64Data = imageData?.split(',')?.[1];
return {
type: 'image',
source: {
type: 'base64',
media_type: mimeType,
data: base64Data,
},
};
}
if (content.type === 'text') {
if (message.name) {
content.text = `${message.name}: ${content.text}`;
}
return content;
}
return content;
});
}
// Remove offending properties
delete message.name;
delete message.tool_calls;
delete message.tool_call_id;
});
// Images in assistant messages should be moved to the next user message
for (let i = 0; i < messages.length; i++) {
if (messages[i].role === 'assistant' && messages[i].content.some(c => c.type === 'image')) {
// Find the next user message
let j = i + 1;
while (j < messages.length && messages[j].role !== 'user') {
j++;
}
// Move the images
if (j >= messages.length) {
// If there is no user message after the assistant message, add a new one
messages.splice(i + 1, 0, { role: 'user', content: [] });
}
messages[j].content.push(...messages[i].content.filter(c => c.type === 'image'));
messages[i].content = messages[i].content.filter(c => c.type !== 'image');
}
}
// Shouldn't be conditional anymore, messages api expects the last role to be user unless we're explicitly prefilling
if (prefillString) {
messages.push({
role: 'assistant',
content: prefillString.trimEnd(),
// Dangling whitespace are not allowed for prefilling
content: [{ type: 'text', text: prefillString.trimEnd() }],
});
}
@@ -143,53 +242,34 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
// Also handle multi-modality, holy slop.
let mergedMessages = [];
messages.forEach((message) => {
const imageEntry = message.content?.[1]?.image_url;
const imageData = imageEntry?.url;
const mimeType = imageData?.split(';')?.[0].split(':')?.[1];
const base64Data = imageData?.split(',')?.[1];
// Take care of name properties since claude messages don't support them
if (message.name) {
if (Array.isArray(message.content)) {
message.content[0].text = `${message.name}: ${message.content[0].text}`;
} else {
message.content = `${message.name}: ${message.content}`;
}
delete message.name;
}
if (mergedMessages.length > 0 && mergedMessages[mergedMessages.length - 1].role === message.role) {
if (Array.isArray(message.content)) {
if (Array.isArray(mergedMessages[mergedMessages.length - 1].content)) {
mergedMessages[mergedMessages.length - 1].content[0].text += '\n\n' + message.content[0].text;
} else {
mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content[0].text;
}
} else {
if (Array.isArray(mergedMessages[mergedMessages.length - 1].content)) {
mergedMessages[mergedMessages.length - 1].content[0].text += '\n\n' + message.content;
} else {
mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content;
}
}
mergedMessages[mergedMessages.length - 1].content.push(...message.content);
} else {
mergedMessages.push(message);
}
if (imageData) {
mergedMessages[mergedMessages.length - 1].content = [
{ type: 'text', text: mergedMessages[mergedMessages.length - 1].content[0]?.text || mergedMessages[mergedMessages.length - 1].content },
{
type: 'image', source: {
type: 'base64',
media_type: mimeType,
data: base64Data,
},
},
];
}
});
return { messages: mergedMessages, systemPrompt: systemPrompt.trim() };
if (!useTools) {
mergedMessages.forEach((message) => {
message.content.forEach((content) => {
if (content.type === 'tool_use') {
content.type = 'text';
content.text = JSON.stringify(content.input);
delete content.id;
delete content.name;
delete content.input;
}
if (content.type === 'tool_result') {
content.type = 'text';
content.text = content.content;
delete content.tool_use_id;
delete content.content;
}
});
});
}
return { messages: mergedMessages, systemPrompt: systemPrompt };
}
/**
@@ -205,7 +285,6 @@ function convertCohereMessages(messages, charName = '', userName = '') {
'user': 'USER',
'assistant': 'CHATBOT',
};
const placeholder = '[Start a new chat]';
let systemPrompt = '';
// Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
@@ -233,12 +312,12 @@ function convertCohereMessages(messages, charName = '', userName = '') {
if (messages.length === 0) {
messages.unshift({
role: 'user',
content: placeholder,
content: PROMPT_PLACEHOLDER,
});
}
const lastNonSystemMessageIndex = messages.findLastIndex(msg => msg.role === 'user' || msg.role === 'assistant');
const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || placeholder;
const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || PROMPT_PLACEHOLDER;
const chatHistory = messages.slice(0, lastNonSystemMessageIndex).map(msg => {
return {
@@ -414,7 +493,7 @@ function convertAI21Messages(messages, charName = '', userName = '') {
if (messages.length === 0) {
messages.unshift({
role: 'user',
content: '[Start a new chat]',
content: PROMPT_PLACEHOLDER,
});
}
@@ -466,8 +545,18 @@ function convertMistralMessages(messages, charName = '', userName = '') {
lastMsg.prefix = true;
}
const sanitizeToolId = (id) => crypto.createHash('sha512').update(id).digest('hex').slice(0, 9);
// Doesn't support completion names, so prepend if not already done by the frontend (e.g. for group chats).
messages.forEach(msg => {
if ('tool_calls' in msg && Array.isArray(msg.tool_calls)) {
msg.tool_calls.forEach(tool => {
tool.id = sanitizeToolId(tool.id);
});
}
if ('tool_call_id' in msg && msg.role === 'tool') {
msg.tool_call_id = sanitizeToolId(msg.tool_call_id);
}
if (msg.role === 'system' && msg.name === 'example_assistant') {
if (charName && !msg.content.startsWith(`${charName}: `)) {
msg.content = `${charName}: ${msg.content}`;
@@ -488,6 +577,28 @@ function convertMistralMessages(messages, charName = '', userName = '') {
}
});
// If user role message immediately follows a tool message, append it to the last user message
const fixToolMessages = () => {
let rerun = true;
while (rerun) {
rerun = false;
messages.forEach((message, i) => {
if (i === messages.length - 1) {
return;
}
if (message.role === 'tool' && messages[i + 1].role === 'user') {
const lastUserMessage = messages.slice(0, i).findLastIndex(m => m.role === 'user' && m.content);
if (lastUserMessage !== -1) {
messages[lastUserMessage].content += '\n\n' + messages[i + 1].content;
messages.splice(i + 1, 1);
rerun = true;
}
}
});
}
};
fixToolMessages();
// If system role message immediately follows an assistant message, change its role to user
for (let i = 0; i < messages.length - 1; i++) {
if (messages[i].role === 'assistant' && messages[i + 1].role === 'system') {
@@ -498,6 +609,83 @@ function convertMistralMessages(messages, charName = '', userName = '') {
return messages;
}
/**
* Merge messages with the same consecutive role, removing names if they exist.
* @param {any[]} messages Messages to merge
* @param {string} charName Character name
* @param {string} userName User name
* @param {boolean} strict Enable strict mode: only allow one system message at the start, force user first message
* @returns {any[]} Merged messages
*/
function mergeMessages(messages, charName, userName, strict) {
let mergedMessages = [];
// Remove names from the messages
messages.forEach((message) => {
if (!message.content) {
message.content = '';
}
if (message.role === 'system' && message.name === 'example_assistant') {
if (charName && !message.content.startsWith(`${charName}: `)) {
message.content = `${charName}: ${message.content}`;
}
}
if (message.role === 'system' && message.name === 'example_user') {
if (userName && !message.content.startsWith(`${userName}: `)) {
message.content = `${userName}: ${message.content}`;
}
}
if (message.name && message.role !== 'system') {
if (!message.content.startsWith(`${message.name}: `)) {
message.content = `${message.name}: ${message.content}`;
}
}
if (message.role === 'tool') {
message.role = 'user';
}
delete message.name;
delete message.tool_calls;
delete message.tool_call_id;
});
// Squash consecutive messages with the same role
messages.forEach((message) => {
if (mergedMessages.length > 0 && mergedMessages[mergedMessages.length - 1].role === message.role && message.content) {
mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content;
} else {
mergedMessages.push(message);
}
});
// Prevent erroring out if the messages array is empty.
if (messages.length === 0) {
messages.unshift({
role: 'user',
content: PROMPT_PLACEHOLDER,
});
}
if (strict) {
for (let i = 0; i < mergedMessages.length; i++) {
// Force mid-prompt system messages to be user messages
if (i > 0 && mergedMessages[i].role === 'system') {
mergedMessages[i].role = 'user';
}
}
if (mergedMessages.length) {
if (mergedMessages[0].role === 'system' && (mergedMessages.length === 1 || mergedMessages[1].role !== 'user')) {
mergedMessages.splice(1, 0, { role: 'user', content: PROMPT_PLACEHOLDER });
}
else if (mergedMessages[0].role !== 'system' && mergedMessages[0].role !== 'user') {
mergedMessages.unshift({ role: 'user', content: PROMPT_PLACEHOLDER });
}
}
return mergeMessages(mergedMessages, charName, userName, false);
}
return mergedMessages;
}
/**
* Convert a prompt from the ChatML objects to the format used by Text Completion API.
* @param {object[]} messages Array of messages
@@ -523,76 +711,6 @@ function convertTextCompletionPrompt(messages) {
return messageStrings.join('\n') + '\nassistant:';
}
/**
* Convert OpenAI Chat Completion tools to the format used by Cohere.
* @param {object[]} tools OpenAI Chat Completion tool definitions
*/
function convertCohereTools(tools) {
if (!Array.isArray(tools) || tools.length === 0) {
return [];
}
const jsonSchemaToPythonTypes = {
'string': 'str',
'number': 'float',
'integer': 'int',
'boolean': 'bool',
'array': 'list',
'object': 'dict',
};
const cohereTools = [];
for (const tool of tools) {
if (tool?.type !== 'function') {
console.log(`Unsupported tool type: ${tool.type}`);
continue;
}
const name = tool?.function?.name;
const description = tool?.function?.description;
const properties = tool?.function?.parameters?.properties;
const required = tool?.function?.parameters?.required;
const parameters = {};
if (!name) {
console.log('Tool name is missing');
continue;
}
if (!description) {
console.log('Tool description is missing');
}
if (!properties || typeof properties !== 'object') {
console.log(`No properties found for tool: ${tool?.function?.name}`);
continue;
}
for (const property in properties) {
const parameterDefinition = properties[property];
const description = parameterDefinition.description || (parameterDefinition.enum ? JSON.stringify(parameterDefinition.enum) : '');
const type = jsonSchemaToPythonTypes[parameterDefinition.type] || 'str';
const isRequired = Array.isArray(required) && required.includes(property);
parameters[property] = {
description: description,
type: type,
required: isRequired,
};
}
const cohereTool = {
name: tool.function.name,
description: tool.function.description,
parameter_definitions: parameters,
};
cohereTools.push(cohereTool);
}
return cohereTools;
}
module.exports = {
convertClaudePrompt,
convertClaudeMessages,
@@ -600,6 +718,6 @@ module.exports = {
convertTextCompletionPrompt,
convertCohereMessages,
convertMistralMessages,
convertCohereTools,
convertAI21Messages,
mergeMessages,
};