Refactor prompt converters with group names awareness

This commit is contained in:
Cohee
2024-12-20 23:30:57 +02:00
parent d7328af4c8
commit 73614f2f8d
4 changed files with 148 additions and 108 deletions

View File

@ -275,6 +275,17 @@ export function getGroupMembers(groupId = selected_group) {
return group?.members.map(member => characters.find(x => x.avatar === member)) ?? [];
}
/**
* Retrieves the member names of a group. If the group is not selected, an empty array is returned.
* @returns {string[]} An array of character names representing the members of the group.
*/
export function getGroupNames() {
const groupMembers = selected_group ? groups.find(x => x.id == selected_group)?.members : null;
return Array.isArray(groupMembers)
? groupMembers.map(x => characters.find(y => y.avatar === x)?.name).filter(x => x)
: [];
}
/**
* Finds the character ID for a group member.
* @param {string} arg 0-based member index or character name

View File

@ -33,7 +33,7 @@ import {
system_message_types,
this_chid,
} from '../script.js';
import { groups, selected_group } from './group-chats.js';
import { getGroupNames, selected_group } from './group-chats.js';
import {
chatCompletionDefaultPrompts,
@ -543,10 +543,7 @@ function setupChatCompletionPromptManager(openAiSettings) {
* @returns {Message[]} Array of message objects
*/
export function parseExampleIntoIndividual(messageExampleString, appendNamesForGroup = true) {
const groupMembers = selected_group ? groups.find(x => x.id == selected_group)?.members : null;
const groupBotNames = Array.isArray(groupMembers)
? groupMembers.map(x => characters.find(y => y.avatar === x)?.name).filter(x => x).map(x => `${x}:`)
: [];
const groupBotNames = getGroupNames().map(name => `${name}:`);
let result = []; // array of msgs
let tmp = messageExampleString.split('\n');
@ -1877,6 +1874,7 @@ async function sendOpenAIRequest(type, messages, signal) {
'n': canMultiSwipe ? oai_settings.n : undefined,
'user_name': name1,
'char_name': name2,
'group_names': getGroupNames(),
};
// Empty array will produce a validation error

View File

@ -27,6 +27,7 @@ import {
mergeMessages,
cachingAtDepthForOpenRouterClaude,
cachingAtDepthForClaude,
getPromptNames,
} from '../../prompt-converters.js';
import { readSecret, SECRET_KEYS } from '../secrets.js';
@ -55,17 +56,16 @@ const API_NANOGPT = 'https://nano-gpt.com/api/v1';
* Applies a post-processing step to the generated messages.
* @param {object[]} messages Messages to post-process
* @param {string} type Prompt conversion type
* @param {string} charName Character name
* @param {string} userName User name
* @param {import('../../prompt-converters.js').PromptNames} names Prompt names
* @returns
*/
function postProcessPrompt(messages, type, charName, userName) {
function postProcessPrompt(messages, type, names) {
switch (type) {
case 'merge':
case 'claude':
return mergeMessages(messages, charName, userName, false);
return mergeMessages(messages, names, false);
case 'strict':
return mergeMessages(messages, charName, userName, true);
return mergeMessages(messages, names, true);
default:
return messages;
}
@ -101,7 +101,7 @@ async function sendClaudeRequest(request, response) {
const additionalHeaders = {};
const useTools = request.body.model.startsWith('claude-3') && Array.isArray(request.body.tools) && request.body.tools.length > 0;
const useSystemPrompt = (request.body.model.startsWith('claude-2') || request.body.model.startsWith('claude-3')) && request.body.claude_use_sysprompt;
const convertedPrompt = convertClaudeMessages(request.body.messages, request.body.assistant_prefill, useSystemPrompt, useTools, request.body.char_name, request.body.user_name);
const convertedPrompt = convertClaudeMessages(request.body.messages, request.body.assistant_prefill, useSystemPrompt, useTools, getPromptNames(request));
// Add custom stop sequences
const stopSequences = [];
if (Array.isArray(request.body.stop)) {
@ -284,7 +284,7 @@ async function sendMakerSuiteRequest(request, response) {
model.startsWith('gemini-exp')
) && request.body.use_makersuite_sysprompt;
const prompt = convertGooglePrompt(request.body.messages, model, should_use_system_prompt, request.body.char_name, request.body.user_name);
const prompt = convertGooglePrompt(request.body.messages, model, should_use_system_prompt, getPromptNames(request));
let body = {
contents: prompt.contents,
safetySettings: GEMINI_SAFETY,
@ -384,7 +384,7 @@ async function sendAI21Request(request, response) {
request.socket.on('close', function () {
controller.abort();
});
const convertedPrompt = convertAI21Messages(request.body.messages, request.body.char_name, request.body.user_name);
const convertedPrompt = convertAI21Messages(request.body.messages, getPromptNames(request));
const body = {
messages: convertedPrompt,
model: request.body.model,
@ -447,7 +447,7 @@ async function sendMistralAIRequest(request, response) {
}
try {
const messages = convertMistralMessages(request.body.messages, request.body.char_name, request.body.user_name);
const messages = convertMistralMessages(request.body.messages, getPromptNames(request));
const controller = new AbortController();
request.socket.removeAllListeners('close');
request.socket.on('close', function () {
@ -528,7 +528,7 @@ async function sendCohereRequest(request, response) {
}
try {
const convertedHistory = convertCohereMessages(request.body.messages, request.body.char_name, request.body.user_name);
const convertedHistory = convertCohereMessages(request.body.messages, getPromptNames(request));
const tools = [];
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
@ -886,15 +886,14 @@ router.post('/generate', jsonParser, function (request, response) {
request.body.messages = postProcessPrompt(
request.body.messages,
request.body.custom_prompt_post_processing,
request.body.char_name,
request.body.user_name);
getPromptNames(request));
}
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.PERPLEXITY) {
apiUrl = API_PERPLEXITY;
apiKey = readSecret(request.user.directories, SECRET_KEYS.PERPLEXITY);
headers = {};
bodyParams = {};
request.body.messages = postProcessPrompt(request.body.messages, 'strict', request.body.char_name, request.body.user_name);
request.body.messages = postProcessPrompt(request.body.messages, 'strict', getPromptNames(request));
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.GROQ) {
apiUrl = API_GROQ;
apiKey = readSecret(request.user.directories, SECRET_KEYS.GROQ);

View File

@ -3,6 +3,30 @@ import { getConfigValue } from './util.js';
const PROMPT_PLACEHOLDER = getConfigValue('promptPlaceholder', 'Let\'s get started.');
/**
* @typedef {object} PromptNames
* @property {string} charName Character name
* @property {string} userName User name
* @property {string[]} groupNames Group member names
* @property {function(string): boolean} startsFromGroupName Check if a message starts with a group name
*/
/**
* Extracts the character name, user name, and group member names from the request.
* @param {import('express').Request} request Express request object
* @returns {PromptNames} Prompt names
*/
export function getPromptNames(request) {
return {
charName: String(request.body.char_name || ''),
userName: String(request.body.user_name || ''),
groupNames: Array.isArray(request.body.group_names) ? request.body.group_names.map(String) : [],
startsFromGroupName: function (message) {
return this.groupNames.some(name => message.startsWith(`${name}: `));
},
};
}
/**
* Convert a prompt from the ChatML objects to the format used by Claude.
* Mainly deprecated. Only used for counting tokens.
@ -91,10 +115,10 @@ export function convertClaudePrompt(messages, addAssistantPostfix, addAssistantP
* @param {string} prefillString User determined prefill string
* @param {boolean} useSysPrompt See if we want to use a system prompt
* @param {boolean} useTools See if we want to use tools
* @param {string} charName Character name
* @param {string} userName User name
* @param {PromptNames} names Prompt names
* @returns {{messages: object[], systemPrompt: object[]}} Prompt for Anthropic
*/
export function convertClaudeMessages(messages, prefillString, useSysPrompt, useTools, charName, userName) {
export function convertClaudeMessages(messages, prefillString, useSysPrompt, useTools, names) {
let systemPrompt = [];
if (useSysPrompt) {
// Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
@ -104,14 +128,14 @@ export function convertClaudeMessages(messages, prefillString, useSysPrompt, use
break;
}
// Append example names if not already done by the frontend (e.g. for group chats).
if (userName && messages[i].name === 'example_user') {
if (!messages[i].content.startsWith(`${userName}: `)) {
messages[i].content = `${userName}: ${messages[i].content}`;
if (names.userName && messages[i].name === 'example_user') {
if (!messages[i].content.startsWith(`${names.userName}: `)) {
messages[i].content = `${names.userName}: ${messages[i].content}`;
}
}
if (charName && messages[i].name === 'example_assistant') {
if (!messages[i].content.startsWith(`${charName}: `)) {
messages[i].content = `${charName}: ${messages[i].content}`;
if (names.charName && messages[i].name === 'example_assistant') {
if (!messages[i].content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(messages[i].content)) {
messages[i].content = `${names.charName}: ${messages[i].content}`;
}
}
systemPrompt.push({ type: 'text', text: messages[i].content });
@ -151,11 +175,15 @@ export function convertClaudeMessages(messages, prefillString, useSysPrompt, use
}
if (message.role === 'system') {
if (userName && message.name === 'example_user') {
message.content = `${userName}: ${message.content}`;
if (names.userName && message.name === 'example_user') {
if (!message.content.startsWith(`${names.userName}: `)) {
message.content = `${names.userName}: ${message.content}`;
}
}
if (names.charName && message.name === 'example_assistant') {
if (!message.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(message.content)) {
message.content = `${names.charName}: ${message.content}`;
}
if (charName && message.name === 'example_assistant') {
message.content = `${charName}: ${message.content}`;
}
message.role = 'user';
@ -274,11 +302,10 @@ export function convertClaudeMessages(messages, prefillString, useSysPrompt, use
/**
* Convert a prompt from the ChatML objects to the format used by Cohere.
* @param {object[]} messages Array of messages
* @param {string} charName Character name
* @param {string} userName User name
* @param {PromptNames} names Prompt names
* @returns {{chatHistory: object[]}} Prompt for Cohere
*/
export function convertCohereMessages(messages, charName = '', userName = '') {
export function convertCohereMessages(messages, names) {
if (messages.length === 0) {
messages.unshift({
role: 'user',
@ -299,13 +326,13 @@ export function convertCohereMessages(messages, charName = '', userName = '') {
// No names support (who would've thought)
if (msg.name) {
if (msg.role == 'system' && msg.name == 'example_assistant') {
if (charName && !msg.content.startsWith(`${charName}: `)) {
msg.content = `${charName}: ${msg.content}`;
if (names.charName && !msg.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(msg.content)) {
msg.content = `${names.charName}: ${msg.content}`;
}
}
if (msg.role == 'system' && msg.name == 'example_user') {
if (userName && !msg.content.startsWith(`${userName}: `)) {
msg.content = `${userName}: ${msg.content}`;
if (names.userName && !msg.content.startsWith(`${names.userName}: `)) {
msg.content = `${names.userName}: ${msg.content}`;
}
}
if (msg.role !== 'system' && !msg.content.startsWith(`${msg.name}: `)) {
@ -328,12 +355,10 @@ export function convertCohereMessages(messages, charName = '', userName = '') {
* @param {object[]} messages Array of messages
* @param {string} model Model name
* @param {boolean} useSysPrompt Use system prompt
* @param {string} charName Character name
* @param {string} userName User name
* @param {PromptNames} names Prompt names
* @returns {{contents: *[], system_instruction: {parts: {text: string}}}} Prompt for Google MakerSuite models
*/
export function convertGooglePrompt(messages, model, useSysPrompt = false, charName = '', userName = '') {
export function convertGooglePrompt(messages, model, useSysPrompt, names) {
const visionSupportedModels = [
'gemini-2.0-flash-exp',
'gemini-1.5-flash',
@ -356,20 +381,19 @@ export function convertGooglePrompt(messages, model, useSysPrompt = false, charN
];
const isMultimodal = visionSupportedModels.includes(model);
let hasImage = false;
let sys_prompt = '';
if (useSysPrompt) {
while (messages.length > 1 && messages[0].role === 'system') {
// Append example names if not already done by the frontend (e.g. for group chats).
if (userName && messages[0].name === 'example_user') {
if (!messages[0].content.startsWith(`${userName}: `)) {
messages[0].content = `${userName}: ${messages[0].content}`;
if (names.userName && messages[0].name === 'example_user') {
if (!messages[0].content.startsWith(`${names.userName}: `)) {
messages[0].content = `${names.userName}: ${messages[0].content}`;
}
}
if (charName && messages[0].name === 'example_assistant') {
if (!messages[0].content.startsWith(`${charName}: `)) {
messages[0].content = `${charName}: ${messages[0].content}`;
if (names.charName && messages[0].name === 'example_assistant') {
if (!messages[0].content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(messages[0].content)) {
messages[0].content = `${names.charName}: ${messages[0].content}`;
}
}
sys_prompt += `${messages[0].content}\n\n`;
@ -388,33 +412,37 @@ export function convertGooglePrompt(messages, model, useSysPrompt = false, charN
message.role = 'model';
}
// similar story as claude
if (message.name) {
if (userName && message.name === 'example_user') {
message.name = userName;
}
if (charName && message.name === 'example_assistant') {
message.name = charName;
// Convert the content to an array of parts
if (!Array.isArray(message.content)) {
message.content = [{ type: 'text', text: String(message.content ?? '') }];
}
if (Array.isArray(message.content)) {
if (!message.content[0].text.startsWith(`${message.name}: `)) {
message.content[0].text = `${message.name}: ${message.content[0].text}`;
// similar story as claude
if (message.name) {
message.content.forEach((part) => {
if (part.type !== 'text') {
return;
}
if (message.name === 'example_user') {
if (!part.text.startsWith(`${names.userName}: `)) {
part.text = `${names.userName}: ${part.text}`;
}
} else if (message.name === 'example_assistant') {
if (!part.text.startsWith(`${names.charName}: `) && !names.startsFromGroupName(part.text)) {
part.text = `${names.charName}: ${part.text}`;
}
} else {
if (!message.content.startsWith(`${message.name}: `)) {
message.content = `${message.name}: ${message.content}`;
if (!part.text.startsWith(`${message.name}: `)) {
part.text = `${message.name}: ${part.text}`;
}
}
});
delete message.name;
}
//create the prompt parts
const parts = [];
if (typeof message.content === 'string') {
parts.push({ text: message.content });
} else if (Array.isArray(message.content)) {
message.content.forEach((part) => {
if (part.type === 'text') {
parts.push({ text: part.text });
@ -427,14 +455,19 @@ export function convertGooglePrompt(messages, model, useSysPrompt = false, charN
data: base64Data,
},
});
hasImage = true;
}
});
}
// merge consecutive messages with the same role
if (index > 0 && message.role === contents[contents.length - 1].role) {
contents[contents.length - 1].parts[0].text += '\n\n' + parts[0].text;
parts.forEach((part) => {
if (part.text) {
contents[contents.length - 1].parts[0].text += '\n\n' + part.text;
}
if (part.inlineData) {
contents[contents.length - 1].parts.push(part);
}
});
} else {
contents.push({
role: message.role,
@ -449,10 +482,10 @@ export function convertGooglePrompt(messages, model, useSysPrompt = false, charN
/**
* Convert AI21 prompt. Classic: system message squash, user/assistant message merge.
* @param {object[]} messages Array of messages
* @param {string} charName Character name
* @param {string} userName User name
* @param {PromptNames} names Prompt names
* @returns {object[]} Prompt for AI21
*/
export function convertAI21Messages(messages, charName = '', userName = '') {
export function convertAI21Messages(messages, names) {
if (!Array.isArray(messages)) {
return [];
}
@ -465,14 +498,14 @@ export function convertAI21Messages(messages, charName = '', userName = '') {
break;
}
// Append example names if not already done by the frontend (e.g. for group chats).
if (userName && messages[i].name === 'example_user') {
if (!messages[i].content.startsWith(`${userName}: `)) {
messages[i].content = `${userName}: ${messages[i].content}`;
if (names.userName && messages[i].name === 'example_user') {
if (!messages[i].content.startsWith(`${names.userName}: `)) {
messages[i].content = `${names.userName}: ${messages[i].content}`;
}
}
if (charName && messages[i].name === 'example_assistant') {
if (!messages[i].content.startsWith(`${charName}: `)) {
messages[i].content = `${charName}: ${messages[i].content}`;
if (names.charName && messages[i].name === 'example_assistant') {
if (!messages[i].content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(messages[i].content)) {
messages[i].content = `${names.charName}: ${messages[i].content}`;
}
}
systemPrompt += `${messages[i].content}\n\n`;
@ -521,10 +554,10 @@ export function convertAI21Messages(messages, charName = '', userName = '') {
/**
* Convert a prompt from the ChatML objects to the format used by MistralAI.
* @param {object[]} messages Array of messages
* @param {string} charName Character name
* @param {string} userName User name
* @param {PromptNames} names Prompt names
* @returns {object[]} Prompt for MistralAI
*/
export function convertMistralMessages(messages, charName = '', userName = '') {
export function convertMistralMessages(messages, names) {
if (!Array.isArray(messages)) {
return [];
}
@ -549,15 +582,15 @@ export function convertMistralMessages(messages, charName = '', userName = '') {
msg.tool_call_id = sanitizeToolId(msg.tool_call_id);
}
if (msg.role === 'system' && msg.name === 'example_assistant') {
if (charName && !msg.content.startsWith(`${charName}: `)) {
msg.content = `${charName}: ${msg.content}`;
if (names.charName && !msg.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(msg.content)) {
msg.content = `${names.charName}: ${msg.content}`;
}
delete msg.name;
}
if (msg.role === 'system' && msg.name === 'example_user') {
if (userName && !msg.content.startsWith(`${userName}: `)) {
msg.content = `${userName}: ${msg.content}`;
if (names.userName && !msg.content.startsWith(`${names.userName}: `)) {
msg.content = `${names.userName}: ${msg.content}`;
}
delete msg.name;
}
@ -603,12 +636,11 @@ export function convertMistralMessages(messages, charName = '', userName = '') {
/**
* Merge messages with the same consecutive role, removing names if they exist.
* @param {any[]} messages Messages to merge
* @param {string} charName Character name
* @param {string} userName User name
* @param {PromptNames} names Prompt names
* @param {boolean} strict Enable strict mode: only allow one system message at the start, force user first message
* @returns {any[]} Merged messages
*/
export function mergeMessages(messages, charName, userName, strict) {
export function mergeMessages(messages, names, strict) {
let mergedMessages = [];
/** @type {Map<string,object>} */
@ -636,13 +668,13 @@ export function mergeMessages(messages, charName, userName, strict) {
message.content = text;
}
if (message.role === 'system' && message.name === 'example_assistant') {
if (charName && !message.content.startsWith(`${charName}: `)) {
message.content = `${charName}: ${message.content}`;
if (names.charName && !message.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(message.content)) {
message.content = `${names.charName}: ${message.content}`;
}
}
if (message.role === 'system' && message.name === 'example_user') {
if (userName && !message.content.startsWith(`${userName}: `)) {
message.content = `${userName}: ${message.content}`;
if (names.userName && !message.content.startsWith(`${names.userName}: `)) {
message.content = `${names.userName}: ${message.content}`;
}
}
if (message.name && message.role !== 'system') {
@ -716,7 +748,7 @@ export function mergeMessages(messages, charName, userName, strict) {
mergedMessages.unshift({ role: 'user', content: PROMPT_PLACEHOLDER });
}
}
return mergeMessages(mergedMessages, charName, userName, false);
return mergeMessages(mergedMessages, names, false);
}
return mergedMessages;