mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Merge branch 'staging' into feat/AdditionalLogins
This commit is contained in:
@@ -4,7 +4,7 @@ const fetch = require('node-fetch').default;
|
||||
const { jsonParser } = require('../../express-common');
|
||||
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
|
||||
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
|
||||
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools, convertAI21Messages } = require('../../prompt-converters');
|
||||
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertAI21Messages, mergeMessages } = require('../../prompt-converters');
|
||||
const CohereStream = require('../../cohere-stream');
|
||||
|
||||
const { readSecret, SECRET_KEYS } = require('../secrets');
|
||||
@@ -31,8 +31,11 @@ const API_AI21 = 'https://api.ai21.com/studio/v1';
|
||||
*/
|
||||
function postProcessPrompt(messages, type, charName, userName) {
|
||||
switch (type) {
|
||||
case 'merge':
|
||||
case 'claude':
|
||||
return convertClaudeMessages(messages, '', false, '', charName, userName).messages;
|
||||
return mergeMessages(messages, charName, userName, false);
|
||||
case 'strict':
|
||||
return mergeMessages(messages, charName, userName, true);
|
||||
default:
|
||||
return messages;
|
||||
}
|
||||
@@ -84,7 +87,7 @@ async function sendClaudeRequest(request, response) {
|
||||
const apiUrl = new URL(request.body.reverse_proxy || API_CLAUDE).toString();
|
||||
const apiKey = request.body.reverse_proxy ? request.body.proxy_password : readSecret(request.user.directories, SECRET_KEYS.CLAUDE);
|
||||
const divider = '-'.repeat(process.stdout.columns);
|
||||
const enableSystemPromptCache = getConfigValue('claude.enableSystemPromptCache', false);
|
||||
const enableSystemPromptCache = getConfigValue('claude.enableSystemPromptCache', false) && request.body.model.startsWith('claude-3');
|
||||
|
||||
if (!apiKey) {
|
||||
console.log(color.red(`Claude API key is missing.\n${divider}`));
|
||||
@@ -98,8 +101,9 @@ async function sendClaudeRequest(request, response) {
|
||||
controller.abort();
|
||||
});
|
||||
const additionalHeaders = {};
|
||||
const useTools = request.body.model.startsWith('claude-3') && Array.isArray(request.body.tools) && request.body.tools.length > 0;
|
||||
const useSystemPrompt = (request.body.model.startsWith('claude-2') || request.body.model.startsWith('claude-3')) && request.body.claude_use_sysprompt;
|
||||
const convertedPrompt = convertClaudeMessages(request.body.messages, request.body.assistant_prefill, useSystemPrompt, request.body.human_sysprompt_message, request.body.char_name, request.body.user_name);
|
||||
const convertedPrompt = convertClaudeMessages(request.body.messages, request.body.assistant_prefill, useSystemPrompt, useTools, request.body.human_sysprompt_message, request.body.char_name, request.body.user_name);
|
||||
// Add custom stop sequences
|
||||
const stopSequences = [];
|
||||
if (Array.isArray(request.body.stop)) {
|
||||
@@ -107,7 +111,7 @@ async function sendClaudeRequest(request, response) {
|
||||
}
|
||||
|
||||
const requestBody = {
|
||||
/** @type {any} */ system: '',
|
||||
/** @type {any} */ system: [],
|
||||
messages: convertedPrompt.messages,
|
||||
model: request.body.model,
|
||||
max_tokens: request.body.max_tokens,
|
||||
@@ -118,23 +122,29 @@ async function sendClaudeRequest(request, response) {
|
||||
stream: request.body.stream,
|
||||
};
|
||||
if (useSystemPrompt) {
|
||||
requestBody.system = enableSystemPromptCache
|
||||
? [{ type: 'text', text: convertedPrompt.systemPrompt, cache_control: { type: 'ephemeral' } }]
|
||||
: convertedPrompt.systemPrompt;
|
||||
if (enableSystemPromptCache && Array.isArray(convertedPrompt.systemPrompt) && convertedPrompt.systemPrompt.length) {
|
||||
convertedPrompt.systemPrompt[convertedPrompt.systemPrompt.length - 1]['cache_control'] = { type: 'ephemeral' };
|
||||
}
|
||||
|
||||
requestBody.system = convertedPrompt.systemPrompt;
|
||||
} else {
|
||||
delete requestBody.system;
|
||||
}
|
||||
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
||||
if (useTools) {
|
||||
// Claude doesn't do prefills on function calls, and doesn't allow empty messages
|
||||
if (convertedPrompt.messages.length && convertedPrompt.messages[convertedPrompt.messages.length - 1].role === 'assistant') {
|
||||
convertedPrompt.messages.push({ role: 'user', content: '.' });
|
||||
}
|
||||
additionalHeaders['anthropic-beta'] = 'tools-2024-05-16';
|
||||
requestBody.tool_choice = { type: request.body.tool_choice === 'required' ? 'any' : 'auto' };
|
||||
requestBody.tool_choice = { type: request.body.tool_choice };
|
||||
requestBody.tools = request.body.tools
|
||||
.filter(tool => tool.type === 'function')
|
||||
.map(tool => tool.function)
|
||||
.map(fn => ({ name: fn.name, description: fn.description, input_schema: fn.parameters }));
|
||||
|
||||
if (enableSystemPromptCache && requestBody.tools.length) {
|
||||
requestBody.tools[requestBody.tools.length - 1]['cache_control'] = { type: 'ephemeral' };
|
||||
}
|
||||
}
|
||||
if (enableSystemPromptCache) {
|
||||
additionalHeaders['anthropic-beta'] = 'prompt-caching-2024-07-31';
|
||||
@@ -483,7 +493,7 @@ async function sendMistralAIRequest(request, response) {
|
||||
|
||||
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
||||
requestBody['tools'] = request.body.tools;
|
||||
requestBody['tool_choice'] = request.body.tool_choice === 'required' ? 'any' : 'auto';
|
||||
requestBody['tool_choice'] = request.body.tool_choice;
|
||||
}
|
||||
|
||||
const config = {
|
||||
@@ -553,12 +563,6 @@ async function sendCohereRequest(request, response) {
|
||||
});
|
||||
}
|
||||
|
||||
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
||||
tools.push(...convertCohereTools(request.body.tools));
|
||||
// Can't have both connectors and tools in the same request
|
||||
connectors.splice(0, connectors.length);
|
||||
}
|
||||
|
||||
// https://docs.cohere.com/reference/chat
|
||||
const requestBody = {
|
||||
stream: Boolean(request.body.stream),
|
||||
@@ -908,24 +912,12 @@ router.post('/generate', jsonParser, function (request, response) {
|
||||
apiKey = readSecret(request.user.directories, SECRET_KEYS.PERPLEXITY);
|
||||
headers = {};
|
||||
bodyParams = {};
|
||||
request.body.messages = postProcessPrompt(request.body.messages, 'claude', request.body.char_name, request.body.user_name);
|
||||
request.body.messages = postProcessPrompt(request.body.messages, 'strict', request.body.char_name, request.body.user_name);
|
||||
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.GROQ) {
|
||||
apiUrl = API_GROQ;
|
||||
apiKey = readSecret(request.user.directories, SECRET_KEYS.GROQ);
|
||||
headers = {};
|
||||
bodyParams = {};
|
||||
|
||||
// 'required' tool choice is not supported by Groq
|
||||
if (request.body.tool_choice === 'required') {
|
||||
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
||||
request.body.tool_choice = request.body.tools.length > 1
|
||||
? 'auto' :
|
||||
{ type: 'function', function: { name: request.body.tools[0]?.function?.name } };
|
||||
|
||||
} else {
|
||||
request.body.tool_choice = 'none';
|
||||
}
|
||||
}
|
||||
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.ZEROONEAI) {
|
||||
apiUrl = API_01AI;
|
||||
apiKey = readSecret(request.user.directories, SECRET_KEYS.ZEROONEAI);
|
||||
@@ -962,7 +954,7 @@ router.post('/generate', jsonParser, function (request, response) {
|
||||
controller.abort();
|
||||
});
|
||||
|
||||
if (!isTextCompletion) {
|
||||
if (!isTextCompletion && Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
||||
bodyParams['tools'] = request.body.tools;
|
||||
bodyParams['tool_choice'] = request.body.tool_choice;
|
||||
}
|
||||
|
@@ -51,7 +51,7 @@ const eratoRepPenWhitelist = [
|
||||
6, 1, 11, 13, 25, 198, 12, 9, 8, 279, 264, 459, 323, 477, 539, 912, 374, 574, 1051, 1550, 1587, 4536, 5828, 15058,
|
||||
3287, 3250, 1461, 1077, 813, 11074, 872, 1202, 1436, 7846, 1288, 13434, 1053, 8434, 617, 9167, 1047, 19117, 706,
|
||||
12775, 649, 4250, 527, 7784, 690, 2834, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 1210, 1359, 608, 220, 596, 956,
|
||||
3077, 44886, 4265, 3358, 2351, 2846, 311, 389, 315, 304, 520, 505, 430
|
||||
3077, 44886, 4265, 3358, 2351, 2846, 311, 389, 315, 304, 520, 505, 430,
|
||||
];
|
||||
|
||||
// Ban the dinkus and asterism
|
||||
|
@@ -22,6 +22,72 @@ const visitHeaders = {
|
||||
'Sec-Fetch-User': '?1',
|
||||
};
|
||||
|
||||
/**
|
||||
* Extract the transcript of a YouTube video
|
||||
* @param {string} videoPageBody HTML of the video page
|
||||
* @param {string} lang Language code
|
||||
* @returns {Promise<string>} Transcript text
|
||||
*/
|
||||
async function extractTranscript(videoPageBody, lang) {
|
||||
const he = require('he');
|
||||
const RE_XML_TRANSCRIPT = /<text start="([^"]*)" dur="([^"]*)">([^<]*)<\/text>/g;
|
||||
const splittedHTML = videoPageBody.split('"captions":');
|
||||
|
||||
if (splittedHTML.length <= 1) {
|
||||
if (videoPageBody.includes('class="g-recaptcha"')) {
|
||||
throw new Error('Too many requests');
|
||||
}
|
||||
if (!videoPageBody.includes('"playabilityStatus":')) {
|
||||
throw new Error('Video is not available');
|
||||
}
|
||||
throw new Error('Transcript not available');
|
||||
}
|
||||
|
||||
const captions = (() => {
|
||||
try {
|
||||
return JSON.parse(splittedHTML[1].split(',"videoDetails')[0].replace('\n', ''));
|
||||
} catch (e) {
|
||||
return undefined;
|
||||
}
|
||||
})()?.['playerCaptionsTracklistRenderer'];
|
||||
|
||||
if (!captions) {
|
||||
throw new Error('Transcript disabled');
|
||||
}
|
||||
|
||||
if (!('captionTracks' in captions)) {
|
||||
throw new Error('Transcript not available');
|
||||
}
|
||||
|
||||
if (lang && !captions.captionTracks.some(track => track.languageCode === lang)) {
|
||||
throw new Error('Transcript not available in this language');
|
||||
}
|
||||
|
||||
const transcriptURL = (lang ? captions.captionTracks.find(track => track.languageCode === lang) : captions.captionTracks[0]).baseUrl;
|
||||
const transcriptResponse = await fetch(transcriptURL, {
|
||||
headers: {
|
||||
...(lang && { 'Accept-Language': lang }),
|
||||
'User-Agent': visitHeaders['User-Agent'],
|
||||
},
|
||||
});
|
||||
|
||||
if (!transcriptResponse.ok) {
|
||||
throw new Error('Transcript request failed');
|
||||
}
|
||||
|
||||
const transcriptBody = await transcriptResponse.text();
|
||||
const results = [...transcriptBody.matchAll(RE_XML_TRANSCRIPT)];
|
||||
const transcript = results.map((result) => ({
|
||||
text: result[3],
|
||||
duration: parseFloat(result[2]),
|
||||
offset: parseFloat(result[1]),
|
||||
lang: lang ?? captions.captionTracks[0].languageCode,
|
||||
}));
|
||||
// The text is double-encoded
|
||||
const transcriptText = transcript.map((line) => he.decode(he.decode(line.text))).join(' ');
|
||||
return transcriptText;
|
||||
}
|
||||
|
||||
router.post('/serpapi', jsonParser, async (request, response) => {
|
||||
try {
|
||||
const key = readSecret(request.user.directories, SECRET_KEYS.SERPAPI);
|
||||
@@ -56,10 +122,9 @@ router.post('/serpapi', jsonParser, async (request, response) => {
|
||||
*/
|
||||
router.post('/transcript', jsonParser, async (request, response) => {
|
||||
try {
|
||||
const he = require('he');
|
||||
const RE_XML_TRANSCRIPT = /<text start="([^"]*)" dur="([^"]*)">([^<]*)<\/text>/g;
|
||||
const id = request.body.id;
|
||||
const lang = request.body.lang;
|
||||
const json = request.body.json;
|
||||
|
||||
if (!id) {
|
||||
console.log('Id is required for /transcript');
|
||||
@@ -74,62 +139,18 @@ router.post('/transcript', jsonParser, async (request, response) => {
|
||||
});
|
||||
|
||||
const videoPageBody = await videoPageResponse.text();
|
||||
const splittedHTML = videoPageBody.split('"captions":');
|
||||
|
||||
if (splittedHTML.length <= 1) {
|
||||
if (videoPageBody.includes('class="g-recaptcha"')) {
|
||||
throw new Error('Too many requests');
|
||||
try {
|
||||
const transcriptText = await extractTranscript(videoPageBody, lang);
|
||||
return json
|
||||
? response.json({ transcript: transcriptText, html: videoPageBody })
|
||||
: response.send(transcriptText);
|
||||
} catch (error) {
|
||||
if (json) {
|
||||
return response.json({ html: videoPageBody, transcript: '' });
|
||||
}
|
||||
if (!videoPageBody.includes('"playabilityStatus":')) {
|
||||
throw new Error('Video is not available');
|
||||
}
|
||||
throw new Error('Transcript not available');
|
||||
throw error;
|
||||
}
|
||||
|
||||
const captions = (() => {
|
||||
try {
|
||||
return JSON.parse(splittedHTML[1].split(',"videoDetails')[0].replace('\n', ''));
|
||||
} catch (e) {
|
||||
return undefined;
|
||||
}
|
||||
})()?.['playerCaptionsTracklistRenderer'];
|
||||
|
||||
if (!captions) {
|
||||
throw new Error('Transcript disabled');
|
||||
}
|
||||
|
||||
if (!('captionTracks' in captions)) {
|
||||
throw new Error('Transcript not available');
|
||||
}
|
||||
|
||||
if (lang && !captions.captionTracks.some(track => track.languageCode === lang)) {
|
||||
throw new Error('Transcript not available in this language');
|
||||
}
|
||||
|
||||
const transcriptURL = (lang ? captions.captionTracks.find(track => track.languageCode === lang) : captions.captionTracks[0]).baseUrl;
|
||||
const transcriptResponse = await fetch(transcriptURL, {
|
||||
headers: {
|
||||
...(lang && { 'Accept-Language': lang }),
|
||||
'User-Agent': visitHeaders['User-Agent'],
|
||||
},
|
||||
});
|
||||
|
||||
if (!transcriptResponse.ok) {
|
||||
throw new Error('Transcript request failed');
|
||||
}
|
||||
|
||||
const transcriptBody = await transcriptResponse.text();
|
||||
const results = [...transcriptBody.matchAll(RE_XML_TRANSCRIPT)];
|
||||
const transcript = results.map((result) => ({
|
||||
text: result[3],
|
||||
duration: parseFloat(result[2]),
|
||||
offset: parseFloat(result[1]),
|
||||
lang: lang ?? captions.captionTracks[0].languageCode,
|
||||
}));
|
||||
// The text is double-encoded
|
||||
const transcriptText = transcript.map((line) => he.decode(he.decode(line.text))).join(' ');
|
||||
|
||||
return response.send(transcriptText);
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
return response.sendStatus(500);
|
||||
|
@@ -1,5 +1,8 @@
|
||||
require('./polyfill.js');
|
||||
const { getConfigValue } = require('./util.js');
|
||||
const crypto = require('crypto');
|
||||
|
||||
const PROMPT_PLACEHOLDER = getConfigValue('promptPlaceholder', 'Let\'s get started.');
|
||||
|
||||
/**
|
||||
* Convert a prompt from the ChatML objects to the format used by Claude.
|
||||
@@ -19,6 +22,14 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill,
|
||||
//Prepare messages for claude.
|
||||
//When 'Exclude Human/Assistant prefixes' checked, setting messages role to the 'system'(last message is exception).
|
||||
if (messages.length > 0) {
|
||||
messages.forEach((m) => {
|
||||
if (!m.content) {
|
||||
m.content = '';
|
||||
}
|
||||
if (m.tool_calls) {
|
||||
m.content += JSON.stringify(m.tool_calls);
|
||||
}
|
||||
});
|
||||
if (excludePrefixes) {
|
||||
messages.slice(0, -1).forEach(message => message.role = 'system');
|
||||
} else {
|
||||
@@ -80,12 +91,13 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill,
|
||||
* @param {object[]} messages Array of messages
|
||||
* @param {string} prefillString User determined prefill string
|
||||
* @param {boolean} useSysPrompt See if we want to use a system prompt
|
||||
* @param {boolean} useTools See if we want to use tools
|
||||
* @param {string} humanMsgFix Add Human message between system prompt and assistant.
|
||||
* @param {string} charName Character name
|
||||
* @param {string} userName User name
|
||||
*/
|
||||
function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFix, charName = '', userName = '') {
|
||||
let systemPrompt = '';
|
||||
function convertClaudeMessages(messages, prefillString, useSysPrompt, useTools, humanMsgFix, charName = '', userName = '') {
|
||||
let systemPrompt = [];
|
||||
if (useSysPrompt) {
|
||||
// Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
|
||||
let i;
|
||||
@@ -104,7 +116,7 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
|
||||
messages[i].content = `${charName}: ${messages[i].content}`;
|
||||
}
|
||||
}
|
||||
systemPrompt += `${messages[i].content}\n\n`;
|
||||
systemPrompt.push({ type: 'text', text: messages[i].content });
|
||||
}
|
||||
|
||||
messages.splice(0, i);
|
||||
@@ -114,12 +126,32 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
|
||||
if (messages.length === 0 || (messages.length > 0 && messages[0].role !== 'user')) {
|
||||
messages.unshift({
|
||||
role: 'user',
|
||||
content: humanMsgFix || '[Start a new chat]',
|
||||
content: humanMsgFix || PROMPT_PLACEHOLDER,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Now replace all further messages that have the role 'system' with the role 'user'. (or all if we're not using one)
|
||||
const parse = (str) => typeof str === 'string' ? JSON.parse(str) : str;
|
||||
messages.forEach((message) => {
|
||||
if (message.role === 'assistant' && message.tool_calls) {
|
||||
message.content = message.tool_calls.map((tc) => ({
|
||||
type: 'tool_use',
|
||||
id: tc.id,
|
||||
name: tc.function.name,
|
||||
input: parse(tc.function.arguments),
|
||||
}));
|
||||
}
|
||||
|
||||
if (message.role === 'tool') {
|
||||
message.role = 'user';
|
||||
message.content = [{
|
||||
type: 'tool_result',
|
||||
tool_use_id: message.tool_call_id,
|
||||
content: message.content,
|
||||
}];
|
||||
}
|
||||
|
||||
if (message.role === 'system') {
|
||||
if (userName && message.name === 'example_user') {
|
||||
message.content = `${userName}: ${message.content}`;
|
||||
@@ -128,14 +160,81 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
|
||||
message.content = `${charName}: ${message.content}`;
|
||||
}
|
||||
message.role = 'user';
|
||||
|
||||
// Delete name here so it doesn't get added later
|
||||
delete message.name;
|
||||
}
|
||||
|
||||
// Convert everything to an array of it would be easier to work with
|
||||
if (typeof message.content === 'string') {
|
||||
// Take care of name properties since claude messages don't support them
|
||||
if (message.name) {
|
||||
message.content = `${message.name}: ${message.content}`;
|
||||
}
|
||||
|
||||
message.content = [{ type: 'text', text: message.content }];
|
||||
} else if (Array.isArray(message.content)) {
|
||||
message.content = message.content.map((content) => {
|
||||
if (content.type === 'image_url') {
|
||||
const imageEntry = content?.image_url;
|
||||
const imageData = imageEntry?.url;
|
||||
const mimeType = imageData?.split(';')?.[0].split(':')?.[1];
|
||||
const base64Data = imageData?.split(',')?.[1];
|
||||
|
||||
return {
|
||||
type: 'image',
|
||||
source: {
|
||||
type: 'base64',
|
||||
media_type: mimeType,
|
||||
data: base64Data,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
if (content.type === 'text') {
|
||||
if (message.name) {
|
||||
content.text = `${message.name}: ${content.text}`;
|
||||
}
|
||||
|
||||
return content;
|
||||
}
|
||||
|
||||
return content;
|
||||
});
|
||||
}
|
||||
|
||||
// Remove offending properties
|
||||
delete message.name;
|
||||
delete message.tool_calls;
|
||||
delete message.tool_call_id;
|
||||
});
|
||||
|
||||
// Images in assistant messages should be moved to the next user message
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
if (messages[i].role === 'assistant' && messages[i].content.some(c => c.type === 'image')) {
|
||||
// Find the next user message
|
||||
let j = i + 1;
|
||||
while (j < messages.length && messages[j].role !== 'user') {
|
||||
j++;
|
||||
}
|
||||
|
||||
// Move the images
|
||||
if (j >= messages.length) {
|
||||
// If there is no user message after the assistant message, add a new one
|
||||
messages.splice(i + 1, 0, { role: 'user', content: [] });
|
||||
}
|
||||
|
||||
messages[j].content.push(...messages[i].content.filter(c => c.type === 'image'));
|
||||
messages[i].content = messages[i].content.filter(c => c.type !== 'image');
|
||||
}
|
||||
}
|
||||
|
||||
// Shouldn't be conditional anymore, messages api expects the last role to be user unless we're explicitly prefilling
|
||||
if (prefillString) {
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
content: prefillString.trimEnd(),
|
||||
// Dangling whitespace are not allowed for prefilling
|
||||
content: [{ type: 'text', text: prefillString.trimEnd() }],
|
||||
});
|
||||
}
|
||||
|
||||
@@ -143,53 +242,34 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
|
||||
// Also handle multi-modality, holy slop.
|
||||
let mergedMessages = [];
|
||||
messages.forEach((message) => {
|
||||
const imageEntry = message.content?.[1]?.image_url;
|
||||
const imageData = imageEntry?.url;
|
||||
const mimeType = imageData?.split(';')?.[0].split(':')?.[1];
|
||||
const base64Data = imageData?.split(',')?.[1];
|
||||
|
||||
// Take care of name properties since claude messages don't support them
|
||||
if (message.name) {
|
||||
if (Array.isArray(message.content)) {
|
||||
message.content[0].text = `${message.name}: ${message.content[0].text}`;
|
||||
} else {
|
||||
message.content = `${message.name}: ${message.content}`;
|
||||
}
|
||||
delete message.name;
|
||||
}
|
||||
|
||||
if (mergedMessages.length > 0 && mergedMessages[mergedMessages.length - 1].role === message.role) {
|
||||
if (Array.isArray(message.content)) {
|
||||
if (Array.isArray(mergedMessages[mergedMessages.length - 1].content)) {
|
||||
mergedMessages[mergedMessages.length - 1].content[0].text += '\n\n' + message.content[0].text;
|
||||
} else {
|
||||
mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content[0].text;
|
||||
}
|
||||
} else {
|
||||
if (Array.isArray(mergedMessages[mergedMessages.length - 1].content)) {
|
||||
mergedMessages[mergedMessages.length - 1].content[0].text += '\n\n' + message.content;
|
||||
} else {
|
||||
mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content;
|
||||
}
|
||||
}
|
||||
mergedMessages[mergedMessages.length - 1].content.push(...message.content);
|
||||
} else {
|
||||
mergedMessages.push(message);
|
||||
}
|
||||
if (imageData) {
|
||||
mergedMessages[mergedMessages.length - 1].content = [
|
||||
{ type: 'text', text: mergedMessages[mergedMessages.length - 1].content[0]?.text || mergedMessages[mergedMessages.length - 1].content },
|
||||
{
|
||||
type: 'image', source: {
|
||||
type: 'base64',
|
||||
media_type: mimeType,
|
||||
data: base64Data,
|
||||
},
|
||||
},
|
||||
];
|
||||
}
|
||||
});
|
||||
|
||||
return { messages: mergedMessages, systemPrompt: systemPrompt.trim() };
|
||||
if (!useTools) {
|
||||
mergedMessages.forEach((message) => {
|
||||
message.content.forEach((content) => {
|
||||
if (content.type === 'tool_use') {
|
||||
content.type = 'text';
|
||||
content.text = JSON.stringify(content.input);
|
||||
delete content.id;
|
||||
delete content.name;
|
||||
delete content.input;
|
||||
}
|
||||
if (content.type === 'tool_result') {
|
||||
content.type = 'text';
|
||||
content.text = content.content;
|
||||
delete content.tool_use_id;
|
||||
delete content.content;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
return { messages: mergedMessages, systemPrompt: systemPrompt };
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -205,7 +285,6 @@ function convertCohereMessages(messages, charName = '', userName = '') {
|
||||
'user': 'USER',
|
||||
'assistant': 'CHATBOT',
|
||||
};
|
||||
const placeholder = '[Start a new chat]';
|
||||
let systemPrompt = '';
|
||||
|
||||
// Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
|
||||
@@ -233,12 +312,12 @@ function convertCohereMessages(messages, charName = '', userName = '') {
|
||||
if (messages.length === 0) {
|
||||
messages.unshift({
|
||||
role: 'user',
|
||||
content: placeholder,
|
||||
content: PROMPT_PLACEHOLDER,
|
||||
});
|
||||
}
|
||||
|
||||
const lastNonSystemMessageIndex = messages.findLastIndex(msg => msg.role === 'user' || msg.role === 'assistant');
|
||||
const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || placeholder;
|
||||
const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || PROMPT_PLACEHOLDER;
|
||||
|
||||
const chatHistory = messages.slice(0, lastNonSystemMessageIndex).map(msg => {
|
||||
return {
|
||||
@@ -414,7 +493,7 @@ function convertAI21Messages(messages, charName = '', userName = '') {
|
||||
if (messages.length === 0) {
|
||||
messages.unshift({
|
||||
role: 'user',
|
||||
content: '[Start a new chat]',
|
||||
content: PROMPT_PLACEHOLDER,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -466,8 +545,18 @@ function convertMistralMessages(messages, charName = '', userName = '') {
|
||||
lastMsg.prefix = true;
|
||||
}
|
||||
|
||||
const sanitizeToolId = (id) => crypto.createHash('sha512').update(id).digest('hex').slice(0, 9);
|
||||
|
||||
// Doesn't support completion names, so prepend if not already done by the frontend (e.g. for group chats).
|
||||
messages.forEach(msg => {
|
||||
if ('tool_calls' in msg && Array.isArray(msg.tool_calls)) {
|
||||
msg.tool_calls.forEach(tool => {
|
||||
tool.id = sanitizeToolId(tool.id);
|
||||
});
|
||||
}
|
||||
if ('tool_call_id' in msg && msg.role === 'tool') {
|
||||
msg.tool_call_id = sanitizeToolId(msg.tool_call_id);
|
||||
}
|
||||
if (msg.role === 'system' && msg.name === 'example_assistant') {
|
||||
if (charName && !msg.content.startsWith(`${charName}: `)) {
|
||||
msg.content = `${charName}: ${msg.content}`;
|
||||
@@ -488,6 +577,28 @@ function convertMistralMessages(messages, charName = '', userName = '') {
|
||||
}
|
||||
});
|
||||
|
||||
// If user role message immediately follows a tool message, append it to the last user message
|
||||
const fixToolMessages = () => {
|
||||
let rerun = true;
|
||||
while (rerun) {
|
||||
rerun = false;
|
||||
messages.forEach((message, i) => {
|
||||
if (i === messages.length - 1) {
|
||||
return;
|
||||
}
|
||||
if (message.role === 'tool' && messages[i + 1].role === 'user') {
|
||||
const lastUserMessage = messages.slice(0, i).findLastIndex(m => m.role === 'user' && m.content);
|
||||
if (lastUserMessage !== -1) {
|
||||
messages[lastUserMessage].content += '\n\n' + messages[i + 1].content;
|
||||
messages.splice(i + 1, 1);
|
||||
rerun = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
fixToolMessages();
|
||||
|
||||
// If system role message immediately follows an assistant message, change its role to user
|
||||
for (let i = 0; i < messages.length - 1; i++) {
|
||||
if (messages[i].role === 'assistant' && messages[i + 1].role === 'system') {
|
||||
@@ -498,6 +609,83 @@ function convertMistralMessages(messages, charName = '', userName = '') {
|
||||
return messages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge messages with the same consecutive role, removing names if they exist.
|
||||
* @param {any[]} messages Messages to merge
|
||||
* @param {string} charName Character name
|
||||
* @param {string} userName User name
|
||||
* @param {boolean} strict Enable strict mode: only allow one system message at the start, force user first message
|
||||
* @returns {any[]} Merged messages
|
||||
*/
|
||||
function mergeMessages(messages, charName, userName, strict) {
|
||||
let mergedMessages = [];
|
||||
|
||||
// Remove names from the messages
|
||||
messages.forEach((message) => {
|
||||
if (!message.content) {
|
||||
message.content = '';
|
||||
}
|
||||
if (message.role === 'system' && message.name === 'example_assistant') {
|
||||
if (charName && !message.content.startsWith(`${charName}: `)) {
|
||||
message.content = `${charName}: ${message.content}`;
|
||||
}
|
||||
}
|
||||
if (message.role === 'system' && message.name === 'example_user') {
|
||||
if (userName && !message.content.startsWith(`${userName}: `)) {
|
||||
message.content = `${userName}: ${message.content}`;
|
||||
}
|
||||
}
|
||||
if (message.name && message.role !== 'system') {
|
||||
if (!message.content.startsWith(`${message.name}: `)) {
|
||||
message.content = `${message.name}: ${message.content}`;
|
||||
}
|
||||
}
|
||||
if (message.role === 'tool') {
|
||||
message.role = 'user';
|
||||
}
|
||||
delete message.name;
|
||||
delete message.tool_calls;
|
||||
delete message.tool_call_id;
|
||||
});
|
||||
|
||||
// Squash consecutive messages with the same role
|
||||
messages.forEach((message) => {
|
||||
if (mergedMessages.length > 0 && mergedMessages[mergedMessages.length - 1].role === message.role && message.content) {
|
||||
mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content;
|
||||
} else {
|
||||
mergedMessages.push(message);
|
||||
}
|
||||
});
|
||||
|
||||
// Prevent erroring out if the messages array is empty.
|
||||
if (messages.length === 0) {
|
||||
messages.unshift({
|
||||
role: 'user',
|
||||
content: PROMPT_PLACEHOLDER,
|
||||
});
|
||||
}
|
||||
|
||||
if (strict) {
|
||||
for (let i = 0; i < mergedMessages.length; i++) {
|
||||
// Force mid-prompt system messages to be user messages
|
||||
if (i > 0 && mergedMessages[i].role === 'system') {
|
||||
mergedMessages[i].role = 'user';
|
||||
}
|
||||
}
|
||||
if (mergedMessages.length) {
|
||||
if (mergedMessages[0].role === 'system' && (mergedMessages.length === 1 || mergedMessages[1].role !== 'user')) {
|
||||
mergedMessages.splice(1, 0, { role: 'user', content: PROMPT_PLACEHOLDER });
|
||||
}
|
||||
else if (mergedMessages[0].role !== 'system' && mergedMessages[0].role !== 'user') {
|
||||
mergedMessages.unshift({ role: 'user', content: PROMPT_PLACEHOLDER });
|
||||
}
|
||||
}
|
||||
return mergeMessages(mergedMessages, charName, userName, false);
|
||||
}
|
||||
|
||||
return mergedMessages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a prompt from the ChatML objects to the format used by Text Completion API.
|
||||
* @param {object[]} messages Array of messages
|
||||
@@ -523,76 +711,6 @@ function convertTextCompletionPrompt(messages) {
|
||||
return messageStrings.join('\n') + '\nassistant:';
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert OpenAI Chat Completion tools to the format used by Cohere.
|
||||
* @param {object[]} tools OpenAI Chat Completion tool definitions
|
||||
*/
|
||||
function convertCohereTools(tools) {
|
||||
if (!Array.isArray(tools) || tools.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const jsonSchemaToPythonTypes = {
|
||||
'string': 'str',
|
||||
'number': 'float',
|
||||
'integer': 'int',
|
||||
'boolean': 'bool',
|
||||
'array': 'list',
|
||||
'object': 'dict',
|
||||
};
|
||||
|
||||
const cohereTools = [];
|
||||
|
||||
for (const tool of tools) {
|
||||
if (tool?.type !== 'function') {
|
||||
console.log(`Unsupported tool type: ${tool.type}`);
|
||||
continue;
|
||||
}
|
||||
|
||||
const name = tool?.function?.name;
|
||||
const description = tool?.function?.description;
|
||||
const properties = tool?.function?.parameters?.properties;
|
||||
const required = tool?.function?.parameters?.required;
|
||||
const parameters = {};
|
||||
|
||||
if (!name) {
|
||||
console.log('Tool name is missing');
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!description) {
|
||||
console.log('Tool description is missing');
|
||||
}
|
||||
|
||||
if (!properties || typeof properties !== 'object') {
|
||||
console.log(`No properties found for tool: ${tool?.function?.name}`);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const property in properties) {
|
||||
const parameterDefinition = properties[property];
|
||||
const description = parameterDefinition.description || (parameterDefinition.enum ? JSON.stringify(parameterDefinition.enum) : '');
|
||||
const type = jsonSchemaToPythonTypes[parameterDefinition.type] || 'str';
|
||||
const isRequired = Array.isArray(required) && required.includes(property);
|
||||
parameters[property] = {
|
||||
description: description,
|
||||
type: type,
|
||||
required: isRequired,
|
||||
};
|
||||
}
|
||||
|
||||
const cohereTool = {
|
||||
name: tool.function.name,
|
||||
description: tool.function.description,
|
||||
parameter_definitions: parameters,
|
||||
};
|
||||
|
||||
cohereTools.push(cohereTool);
|
||||
}
|
||||
|
||||
return cohereTools;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
convertClaudePrompt,
|
||||
convertClaudeMessages,
|
||||
@@ -600,6 +718,6 @@ module.exports = {
|
||||
convertTextCompletionPrompt,
|
||||
convertCohereMessages,
|
||||
convertMistralMessages,
|
||||
convertCohereTools,
|
||||
convertAI21Messages,
|
||||
mergeMessages,
|
||||
};
|
||||
|
Reference in New Issue
Block a user