diff --git a/public/scripts/group-chats.js b/public/scripts/group-chats.js index 1152b2471..86e36bead 100644 --- a/public/scripts/group-chats.js +++ b/public/scripts/group-chats.js @@ -275,6 +275,17 @@ export function getGroupMembers(groupId = selected_group) { return group?.members.map(member => characters.find(x => x.avatar === member)) ?? []; } +/** + * Retrieves the member names of a group. If the group is not selected, an empty array is returned. + * @returns {string[]} An array of character names representing the members of the group. + */ +export function getGroupNames() { + const groupMembers = selected_group ? groups.find(x => x.id == selected_group)?.members : null; + return Array.isArray(groupMembers) + ? groupMembers.map(x => characters.find(y => y.avatar === x)?.name).filter(x => x) + : []; +} + /** * Finds the character ID for a group member. * @param {string} arg 0-based member index or character name diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 8edc3c682..914d6d24a 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -33,7 +33,7 @@ import { system_message_types, this_chid, } from '../script.js'; -import { groups, selected_group } from './group-chats.js'; +import { getGroupNames, selected_group } from './group-chats.js'; import { chatCompletionDefaultPrompts, @@ -543,10 +543,7 @@ function setupChatCompletionPromptManager(openAiSettings) { * @returns {Message[]} Array of message objects */ export function parseExampleIntoIndividual(messageExampleString, appendNamesForGroup = true) { - const groupMembers = selected_group ? groups.find(x => x.id == selected_group)?.members : null; - const groupBotNames = Array.isArray(groupMembers) - ? groupMembers.map(x => characters.find(y => y.avatar === x)?.name).filter(x => x).map(x => `${x}:`) - : []; + const groupBotNames = getGroupNames().map(name => `${name}:`); let result = []; // array of msgs let tmp = messageExampleString.split('\n'); @@ -1877,6 +1874,7 @@ async function sendOpenAIRequest(type, messages, signal) { 'n': canMultiSwipe ? oai_settings.n : undefined, 'user_name': name1, 'char_name': name2, + 'group_names': getGroupNames(), }; // Empty array will produce a validation error diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index 519371ad4..64e6ab345 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -27,6 +27,7 @@ import { mergeMessages, cachingAtDepthForOpenRouterClaude, cachingAtDepthForClaude, + getPromptNames, } from '../../prompt-converters.js'; import { readSecret, SECRET_KEYS } from '../secrets.js'; @@ -55,17 +56,16 @@ const API_NANOGPT = 'https://nano-gpt.com/api/v1'; * Applies a post-processing step to the generated messages. * @param {object[]} messages Messages to post-process * @param {string} type Prompt conversion type - * @param {string} charName Character name - * @param {string} userName User name + * @param {import('../../prompt-converters.js').PromptNames} names Prompt names * @returns */ -function postProcessPrompt(messages, type, charName, userName) { +function postProcessPrompt(messages, type, names) { switch (type) { case 'merge': case 'claude': - return mergeMessages(messages, charName, userName, false); + return mergeMessages(messages, names, false); case 'strict': - return mergeMessages(messages, charName, userName, true); + return mergeMessages(messages, names, true); default: return messages; } @@ -101,7 +101,7 @@ async function sendClaudeRequest(request, response) { const additionalHeaders = {}; const useTools = request.body.model.startsWith('claude-3') && Array.isArray(request.body.tools) && request.body.tools.length > 0; const useSystemPrompt = (request.body.model.startsWith('claude-2') || request.body.model.startsWith('claude-3')) && request.body.claude_use_sysprompt; - const convertedPrompt = convertClaudeMessages(request.body.messages, request.body.assistant_prefill, useSystemPrompt, useTools, request.body.char_name, request.body.user_name); + const convertedPrompt = convertClaudeMessages(request.body.messages, request.body.assistant_prefill, useSystemPrompt, useTools, getPromptNames(request)); // Add custom stop sequences const stopSequences = []; if (Array.isArray(request.body.stop)) { @@ -282,9 +282,9 @@ async function sendMakerSuiteRequest(request, response) { model.includes('gemini-1.5-flash') || model.includes('gemini-1.5-pro') || model.startsWith('gemini-exp') - ) && request.body.use_makersuite_sysprompt; + ) && request.body.use_makersuite_sysprompt; - const prompt = convertGooglePrompt(request.body.messages, model, should_use_system_prompt, request.body.char_name, request.body.user_name); + const prompt = convertGooglePrompt(request.body.messages, model, should_use_system_prompt, getPromptNames(request)); let body = { contents: prompt.contents, safetySettings: GEMINI_SAFETY, @@ -384,7 +384,7 @@ async function sendAI21Request(request, response) { request.socket.on('close', function () { controller.abort(); }); - const convertedPrompt = convertAI21Messages(request.body.messages, request.body.char_name, request.body.user_name); + const convertedPrompt = convertAI21Messages(request.body.messages, getPromptNames(request)); const body = { messages: convertedPrompt, model: request.body.model, @@ -447,7 +447,7 @@ async function sendMistralAIRequest(request, response) { } try { - const messages = convertMistralMessages(request.body.messages, request.body.char_name, request.body.user_name); + const messages = convertMistralMessages(request.body.messages, getPromptNames(request)); const controller = new AbortController(); request.socket.removeAllListeners('close'); request.socket.on('close', function () { @@ -528,7 +528,7 @@ async function sendCohereRequest(request, response) { } try { - const convertedHistory = convertCohereMessages(request.body.messages, request.body.char_name, request.body.user_name); + const convertedHistory = convertCohereMessages(request.body.messages, getPromptNames(request)); const tools = []; if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { @@ -886,15 +886,14 @@ router.post('/generate', jsonParser, function (request, response) { request.body.messages = postProcessPrompt( request.body.messages, request.body.custom_prompt_post_processing, - request.body.char_name, - request.body.user_name); + getPromptNames(request)); } } else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.PERPLEXITY) { apiUrl = API_PERPLEXITY; apiKey = readSecret(request.user.directories, SECRET_KEYS.PERPLEXITY); headers = {}; bodyParams = {}; - request.body.messages = postProcessPrompt(request.body.messages, 'strict', request.body.char_name, request.body.user_name); + request.body.messages = postProcessPrompt(request.body.messages, 'strict', getPromptNames(request)); } else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.GROQ) { apiUrl = API_GROQ; apiKey = readSecret(request.user.directories, SECRET_KEYS.GROQ); diff --git a/src/prompt-converters.js b/src/prompt-converters.js index 47f5989cc..cdee18d48 100644 --- a/src/prompt-converters.js +++ b/src/prompt-converters.js @@ -3,6 +3,30 @@ import { getConfigValue } from './util.js'; const PROMPT_PLACEHOLDER = getConfigValue('promptPlaceholder', 'Let\'s get started.'); +/** + * @typedef {object} PromptNames + * @property {string} charName Character name + * @property {string} userName User name + * @property {string[]} groupNames Group member names + * @property {function(string): boolean} startsFromGroupName Check if a message starts with a group name + */ + +/** + * Extracts the character name, user name, and group member names from the request. + * @param {import('express').Request} request Express request object + * @returns {PromptNames} Prompt names + */ +export function getPromptNames(request) { + return { + charName: String(request.body.char_name || ''), + userName: String(request.body.user_name || ''), + groupNames: Array.isArray(request.body.group_names) ? request.body.group_names.map(String) : [], + startsFromGroupName: function (message) { + return this.groupNames.some(name => message.startsWith(`${name}: `)); + }, + }; +} + /** * Convert a prompt from the ChatML objects to the format used by Claude. * Mainly deprecated. Only used for counting tokens. @@ -91,10 +115,10 @@ export function convertClaudePrompt(messages, addAssistantPostfix, addAssistantP * @param {string} prefillString User determined prefill string * @param {boolean} useSysPrompt See if we want to use a system prompt * @param {boolean} useTools See if we want to use tools - * @param {string} charName Character name - * @param {string} userName User name + * @param {PromptNames} names Prompt names + * @returns {{messages: object[], systemPrompt: object[]}} Prompt for Anthropic */ -export function convertClaudeMessages(messages, prefillString, useSysPrompt, useTools, charName, userName) { +export function convertClaudeMessages(messages, prefillString, useSysPrompt, useTools, names) { let systemPrompt = []; if (useSysPrompt) { // Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array. @@ -104,14 +128,14 @@ export function convertClaudeMessages(messages, prefillString, useSysPrompt, use break; } // Append example names if not already done by the frontend (e.g. for group chats). - if (userName && messages[i].name === 'example_user') { - if (!messages[i].content.startsWith(`${userName}: `)) { - messages[i].content = `${userName}: ${messages[i].content}`; + if (names.userName && messages[i].name === 'example_user') { + if (!messages[i].content.startsWith(`${names.userName}: `)) { + messages[i].content = `${names.userName}: ${messages[i].content}`; } } - if (charName && messages[i].name === 'example_assistant') { - if (!messages[i].content.startsWith(`${charName}: `)) { - messages[i].content = `${charName}: ${messages[i].content}`; + if (names.charName && messages[i].name === 'example_assistant') { + if (!messages[i].content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(messages[i].content)) { + messages[i].content = `${names.charName}: ${messages[i].content}`; } } systemPrompt.push({ type: 'text', text: messages[i].content }); @@ -151,11 +175,15 @@ export function convertClaudeMessages(messages, prefillString, useSysPrompt, use } if (message.role === 'system') { - if (userName && message.name === 'example_user') { - message.content = `${userName}: ${message.content}`; + if (names.userName && message.name === 'example_user') { + if (!message.content.startsWith(`${names.userName}: `)) { + message.content = `${names.userName}: ${message.content}`; + } } - if (charName && message.name === 'example_assistant') { - message.content = `${charName}: ${message.content}`; + if (names.charName && message.name === 'example_assistant') { + if (!message.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(message.content)) { + message.content = `${names.charName}: ${message.content}`; + } } message.role = 'user'; @@ -274,11 +302,10 @@ export function convertClaudeMessages(messages, prefillString, useSysPrompt, use /** * Convert a prompt from the ChatML objects to the format used by Cohere. * @param {object[]} messages Array of messages - * @param {string} charName Character name - * @param {string} userName User name + * @param {PromptNames} names Prompt names * @returns {{chatHistory: object[]}} Prompt for Cohere */ -export function convertCohereMessages(messages, charName = '', userName = '') { +export function convertCohereMessages(messages, names) { if (messages.length === 0) { messages.unshift({ role: 'user', @@ -299,13 +326,13 @@ export function convertCohereMessages(messages, charName = '', userName = '') { // No names support (who would've thought) if (msg.name) { if (msg.role == 'system' && msg.name == 'example_assistant') { - if (charName && !msg.content.startsWith(`${charName}: `)) { - msg.content = `${charName}: ${msg.content}`; + if (names.charName && !msg.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(msg.content)) { + msg.content = `${names.charName}: ${msg.content}`; } } if (msg.role == 'system' && msg.name == 'example_user') { - if (userName && !msg.content.startsWith(`${userName}: `)) { - msg.content = `${userName}: ${msg.content}`; + if (names.userName && !msg.content.startsWith(`${names.userName}: `)) { + msg.content = `${names.userName}: ${msg.content}`; } } if (msg.role !== 'system' && !msg.content.startsWith(`${msg.name}: `)) { @@ -328,12 +355,10 @@ export function convertCohereMessages(messages, charName = '', userName = '') { * @param {object[]} messages Array of messages * @param {string} model Model name * @param {boolean} useSysPrompt Use system prompt - * @param {string} charName Character name - * @param {string} userName User name + * @param {PromptNames} names Prompt names * @returns {{contents: *[], system_instruction: {parts: {text: string}}}} Prompt for Google MakerSuite models */ -export function convertGooglePrompt(messages, model, useSysPrompt = false, charName = '', userName = '') { - +export function convertGooglePrompt(messages, model, useSysPrompt, names) { const visionSupportedModels = [ 'gemini-2.0-flash-exp', 'gemini-1.5-flash', @@ -356,20 +381,19 @@ export function convertGooglePrompt(messages, model, useSysPrompt = false, charN ]; const isMultimodal = visionSupportedModels.includes(model); - let hasImage = false; let sys_prompt = ''; if (useSysPrompt) { while (messages.length > 1 && messages[0].role === 'system') { // Append example names if not already done by the frontend (e.g. for group chats). - if (userName && messages[0].name === 'example_user') { - if (!messages[0].content.startsWith(`${userName}: `)) { - messages[0].content = `${userName}: ${messages[0].content}`; + if (names.userName && messages[0].name === 'example_user') { + if (!messages[0].content.startsWith(`${names.userName}: `)) { + messages[0].content = `${names.userName}: ${messages[0].content}`; } } - if (charName && messages[0].name === 'example_assistant') { - if (!messages[0].content.startsWith(`${charName}: `)) { - messages[0].content = `${charName}: ${messages[0].content}`; + if (names.charName && messages[0].name === 'example_assistant') { + if (!messages[0].content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(messages[0].content)) { + messages[0].content = `${names.charName}: ${messages[0].content}`; } } sys_prompt += `${messages[0].content}\n\n`; @@ -388,53 +412,62 @@ export function convertGooglePrompt(messages, model, useSysPrompt = false, charN message.role = 'model'; } + // Convert the content to an array of parts + if (!Array.isArray(message.content)) { + message.content = [{ type: 'text', text: String(message.content ?? '') }]; + } + // similar story as claude if (message.name) { - if (userName && message.name === 'example_user') { - message.name = userName; - } - if (charName && message.name === 'example_assistant') { - message.name = charName; - } - - if (Array.isArray(message.content)) { - if (!message.content[0].text.startsWith(`${message.name}: `)) { - message.content[0].text = `${message.name}: ${message.content[0].text}`; + message.content.forEach((part) => { + if (part.type !== 'text') { + return; } - } else { - if (!message.content.startsWith(`${message.name}: `)) { - message.content = `${message.name}: ${message.content}`; + if (message.name === 'example_user') { + if (!part.text.startsWith(`${names.userName}: `)) { + part.text = `${names.userName}: ${part.text}`; + } + } else if (message.name === 'example_assistant') { + if (!part.text.startsWith(`${names.charName}: `) && !names.startsFromGroupName(part.text)) { + part.text = `${names.charName}: ${part.text}`; + } + } else { + if (!part.text.startsWith(`${message.name}: `)) { + part.text = `${message.name}: ${part.text}`; + } } - } + }); delete message.name; } //create the prompt parts const parts = []; - if (typeof message.content === 'string') { - parts.push({ text: message.content }); - } else if (Array.isArray(message.content)) { - message.content.forEach((part) => { - if (part.type === 'text') { - parts.push({ text: part.text }); - } else if (part.type === 'image_url' && isMultimodal) { - const mimeType = part.image_url.url.split(';')[0].split(':')[1]; - const base64Data = part.image_url.url.split(',')[1]; - parts.push({ - inlineData: { - mimeType: mimeType, - data: base64Data, - }, - }); - hasImage = true; - } - }); - } + message.content.forEach((part) => { + if (part.type === 'text') { + parts.push({ text: part.text }); + } else if (part.type === 'image_url' && isMultimodal) { + const mimeType = part.image_url.url.split(';')[0].split(':')[1]; + const base64Data = part.image_url.url.split(',')[1]; + parts.push({ + inlineData: { + mimeType: mimeType, + data: base64Data, + }, + }); + } + }); // merge consecutive messages with the same role if (index > 0 && message.role === contents[contents.length - 1].role) { - contents[contents.length - 1].parts[0].text += '\n\n' + parts[0].text; + parts.forEach((part) => { + if (part.text) { + contents[contents.length - 1].parts[0].text += '\n\n' + part.text; + } + if (part.inlineData) { + contents[contents.length - 1].parts.push(part); + } + }); } else { contents.push({ role: message.role, @@ -449,10 +482,10 @@ export function convertGooglePrompt(messages, model, useSysPrompt = false, charN /** * Convert AI21 prompt. Classic: system message squash, user/assistant message merge. * @param {object[]} messages Array of messages - * @param {string} charName Character name - * @param {string} userName User name + * @param {PromptNames} names Prompt names + * @returns {object[]} Prompt for AI21 */ -export function convertAI21Messages(messages, charName = '', userName = '') { +export function convertAI21Messages(messages, names) { if (!Array.isArray(messages)) { return []; } @@ -465,14 +498,14 @@ export function convertAI21Messages(messages, charName = '', userName = '') { break; } // Append example names if not already done by the frontend (e.g. for group chats). - if (userName && messages[i].name === 'example_user') { - if (!messages[i].content.startsWith(`${userName}: `)) { - messages[i].content = `${userName}: ${messages[i].content}`; + if (names.userName && messages[i].name === 'example_user') { + if (!messages[i].content.startsWith(`${names.userName}: `)) { + messages[i].content = `${names.userName}: ${messages[i].content}`; } } - if (charName && messages[i].name === 'example_assistant') { - if (!messages[i].content.startsWith(`${charName}: `)) { - messages[i].content = `${charName}: ${messages[i].content}`; + if (names.charName && messages[i].name === 'example_assistant') { + if (!messages[i].content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(messages[i].content)) { + messages[i].content = `${names.charName}: ${messages[i].content}`; } } systemPrompt += `${messages[i].content}\n\n`; @@ -521,10 +554,10 @@ export function convertAI21Messages(messages, charName = '', userName = '') { /** * Convert a prompt from the ChatML objects to the format used by MistralAI. * @param {object[]} messages Array of messages - * @param {string} charName Character name - * @param {string} userName User name + * @param {PromptNames} names Prompt names + * @returns {object[]} Prompt for MistralAI */ -export function convertMistralMessages(messages, charName = '', userName = '') { +export function convertMistralMessages(messages, names) { if (!Array.isArray(messages)) { return []; } @@ -549,15 +582,15 @@ export function convertMistralMessages(messages, charName = '', userName = '') { msg.tool_call_id = sanitizeToolId(msg.tool_call_id); } if (msg.role === 'system' && msg.name === 'example_assistant') { - if (charName && !msg.content.startsWith(`${charName}: `)) { - msg.content = `${charName}: ${msg.content}`; + if (names.charName && !msg.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(msg.content)) { + msg.content = `${names.charName}: ${msg.content}`; } delete msg.name; } if (msg.role === 'system' && msg.name === 'example_user') { - if (userName && !msg.content.startsWith(`${userName}: `)) { - msg.content = `${userName}: ${msg.content}`; + if (names.userName && !msg.content.startsWith(`${names.userName}: `)) { + msg.content = `${names.userName}: ${msg.content}`; } delete msg.name; } @@ -603,12 +636,11 @@ export function convertMistralMessages(messages, charName = '', userName = '') { /** * Merge messages with the same consecutive role, removing names if they exist. * @param {any[]} messages Messages to merge - * @param {string} charName Character name - * @param {string} userName User name + * @param {PromptNames} names Prompt names * @param {boolean} strict Enable strict mode: only allow one system message at the start, force user first message * @returns {any[]} Merged messages */ -export function mergeMessages(messages, charName, userName, strict) { +export function mergeMessages(messages, names, strict) { let mergedMessages = []; /** @type {Map} */ @@ -636,13 +668,13 @@ export function mergeMessages(messages, charName, userName, strict) { message.content = text; } if (message.role === 'system' && message.name === 'example_assistant') { - if (charName && !message.content.startsWith(`${charName}: `)) { - message.content = `${charName}: ${message.content}`; + if (names.charName && !message.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(message.content)) { + message.content = `${names.charName}: ${message.content}`; } } if (message.role === 'system' && message.name === 'example_user') { - if (userName && !message.content.startsWith(`${userName}: `)) { - message.content = `${userName}: ${message.content}`; + if (names.userName && !message.content.startsWith(`${names.userName}: `)) { + message.content = `${names.userName}: ${message.content}`; } } if (message.name && message.role !== 'system') { @@ -716,7 +748,7 @@ export function mergeMessages(messages, charName, userName, strict) { mergedMessages.unshift({ role: 'user', content: PROMPT_PLACEHOLDER }); } } - return mergeMessages(mergedMessages, charName, userName, false); + return mergeMessages(mergedMessages, names, false); } return mergedMessages;