mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Refactor prompt converters with group names awareness
This commit is contained in:
@ -3,6 +3,30 @@ import { getConfigValue } from './util.js';
|
||||
|
||||
const PROMPT_PLACEHOLDER = getConfigValue('promptPlaceholder', 'Let\'s get started.');
|
||||
|
||||
/**
|
||||
* @typedef {object} PromptNames
|
||||
* @property {string} charName Character name
|
||||
* @property {string} userName User name
|
||||
* @property {string[]} groupNames Group member names
|
||||
* @property {function(string): boolean} startsFromGroupName Check if a message starts with a group name
|
||||
*/
|
||||
|
||||
/**
|
||||
* Extracts the character name, user name, and group member names from the request.
|
||||
* @param {import('express').Request} request Express request object
|
||||
* @returns {PromptNames} Prompt names
|
||||
*/
|
||||
export function getPromptNames(request) {
|
||||
return {
|
||||
charName: String(request.body.char_name || ''),
|
||||
userName: String(request.body.user_name || ''),
|
||||
groupNames: Array.isArray(request.body.group_names) ? request.body.group_names.map(String) : [],
|
||||
startsFromGroupName: function (message) {
|
||||
return this.groupNames.some(name => message.startsWith(`${name}: `));
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a prompt from the ChatML objects to the format used by Claude.
|
||||
* Mainly deprecated. Only used for counting tokens.
|
||||
@ -91,10 +115,10 @@ export function convertClaudePrompt(messages, addAssistantPostfix, addAssistantP
|
||||
* @param {string} prefillString User determined prefill string
|
||||
* @param {boolean} useSysPrompt See if we want to use a system prompt
|
||||
* @param {boolean} useTools See if we want to use tools
|
||||
* @param {string} charName Character name
|
||||
* @param {string} userName User name
|
||||
* @param {PromptNames} names Prompt names
|
||||
* @returns {{messages: object[], systemPrompt: object[]}} Prompt for Anthropic
|
||||
*/
|
||||
export function convertClaudeMessages(messages, prefillString, useSysPrompt, useTools, charName, userName) {
|
||||
export function convertClaudeMessages(messages, prefillString, useSysPrompt, useTools, names) {
|
||||
let systemPrompt = [];
|
||||
if (useSysPrompt) {
|
||||
// Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
|
||||
@ -104,14 +128,14 @@ export function convertClaudeMessages(messages, prefillString, useSysPrompt, use
|
||||
break;
|
||||
}
|
||||
// Append example names if not already done by the frontend (e.g. for group chats).
|
||||
if (userName && messages[i].name === 'example_user') {
|
||||
if (!messages[i].content.startsWith(`${userName}: `)) {
|
||||
messages[i].content = `${userName}: ${messages[i].content}`;
|
||||
if (names.userName && messages[i].name === 'example_user') {
|
||||
if (!messages[i].content.startsWith(`${names.userName}: `)) {
|
||||
messages[i].content = `${names.userName}: ${messages[i].content}`;
|
||||
}
|
||||
}
|
||||
if (charName && messages[i].name === 'example_assistant') {
|
||||
if (!messages[i].content.startsWith(`${charName}: `)) {
|
||||
messages[i].content = `${charName}: ${messages[i].content}`;
|
||||
if (names.charName && messages[i].name === 'example_assistant') {
|
||||
if (!messages[i].content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(messages[i].content)) {
|
||||
messages[i].content = `${names.charName}: ${messages[i].content}`;
|
||||
}
|
||||
}
|
||||
systemPrompt.push({ type: 'text', text: messages[i].content });
|
||||
@ -151,11 +175,15 @@ export function convertClaudeMessages(messages, prefillString, useSysPrompt, use
|
||||
}
|
||||
|
||||
if (message.role === 'system') {
|
||||
if (userName && message.name === 'example_user') {
|
||||
message.content = `${userName}: ${message.content}`;
|
||||
if (names.userName && message.name === 'example_user') {
|
||||
if (!message.content.startsWith(`${names.userName}: `)) {
|
||||
message.content = `${names.userName}: ${message.content}`;
|
||||
}
|
||||
}
|
||||
if (charName && message.name === 'example_assistant') {
|
||||
message.content = `${charName}: ${message.content}`;
|
||||
if (names.charName && message.name === 'example_assistant') {
|
||||
if (!message.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(message.content)) {
|
||||
message.content = `${names.charName}: ${message.content}`;
|
||||
}
|
||||
}
|
||||
message.role = 'user';
|
||||
|
||||
@ -274,11 +302,10 @@ export function convertClaudeMessages(messages, prefillString, useSysPrompt, use
|
||||
/**
|
||||
* Convert a prompt from the ChatML objects to the format used by Cohere.
|
||||
* @param {object[]} messages Array of messages
|
||||
* @param {string} charName Character name
|
||||
* @param {string} userName User name
|
||||
* @param {PromptNames} names Prompt names
|
||||
* @returns {{chatHistory: object[]}} Prompt for Cohere
|
||||
*/
|
||||
export function convertCohereMessages(messages, charName = '', userName = '') {
|
||||
export function convertCohereMessages(messages, names) {
|
||||
if (messages.length === 0) {
|
||||
messages.unshift({
|
||||
role: 'user',
|
||||
@ -299,13 +326,13 @@ export function convertCohereMessages(messages, charName = '', userName = '') {
|
||||
// No names support (who would've thought)
|
||||
if (msg.name) {
|
||||
if (msg.role == 'system' && msg.name == 'example_assistant') {
|
||||
if (charName && !msg.content.startsWith(`${charName}: `)) {
|
||||
msg.content = `${charName}: ${msg.content}`;
|
||||
if (names.charName && !msg.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(msg.content)) {
|
||||
msg.content = `${names.charName}: ${msg.content}`;
|
||||
}
|
||||
}
|
||||
if (msg.role == 'system' && msg.name == 'example_user') {
|
||||
if (userName && !msg.content.startsWith(`${userName}: `)) {
|
||||
msg.content = `${userName}: ${msg.content}`;
|
||||
if (names.userName && !msg.content.startsWith(`${names.userName}: `)) {
|
||||
msg.content = `${names.userName}: ${msg.content}`;
|
||||
}
|
||||
}
|
||||
if (msg.role !== 'system' && !msg.content.startsWith(`${msg.name}: `)) {
|
||||
@ -328,12 +355,10 @@ export function convertCohereMessages(messages, charName = '', userName = '') {
|
||||
* @param {object[]} messages Array of messages
|
||||
* @param {string} model Model name
|
||||
* @param {boolean} useSysPrompt Use system prompt
|
||||
* @param {string} charName Character name
|
||||
* @param {string} userName User name
|
||||
* @param {PromptNames} names Prompt names
|
||||
* @returns {{contents: *[], system_instruction: {parts: {text: string}}}} Prompt for Google MakerSuite models
|
||||
*/
|
||||
export function convertGooglePrompt(messages, model, useSysPrompt = false, charName = '', userName = '') {
|
||||
|
||||
export function convertGooglePrompt(messages, model, useSysPrompt, names) {
|
||||
const visionSupportedModels = [
|
||||
'gemini-2.0-flash-exp',
|
||||
'gemini-1.5-flash',
|
||||
@ -356,20 +381,19 @@ export function convertGooglePrompt(messages, model, useSysPrompt = false, charN
|
||||
];
|
||||
|
||||
const isMultimodal = visionSupportedModels.includes(model);
|
||||
let hasImage = false;
|
||||
|
||||
let sys_prompt = '';
|
||||
if (useSysPrompt) {
|
||||
while (messages.length > 1 && messages[0].role === 'system') {
|
||||
// Append example names if not already done by the frontend (e.g. for group chats).
|
||||
if (userName && messages[0].name === 'example_user') {
|
||||
if (!messages[0].content.startsWith(`${userName}: `)) {
|
||||
messages[0].content = `${userName}: ${messages[0].content}`;
|
||||
if (names.userName && messages[0].name === 'example_user') {
|
||||
if (!messages[0].content.startsWith(`${names.userName}: `)) {
|
||||
messages[0].content = `${names.userName}: ${messages[0].content}`;
|
||||
}
|
||||
}
|
||||
if (charName && messages[0].name === 'example_assistant') {
|
||||
if (!messages[0].content.startsWith(`${charName}: `)) {
|
||||
messages[0].content = `${charName}: ${messages[0].content}`;
|
||||
if (names.charName && messages[0].name === 'example_assistant') {
|
||||
if (!messages[0].content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(messages[0].content)) {
|
||||
messages[0].content = `${names.charName}: ${messages[0].content}`;
|
||||
}
|
||||
}
|
||||
sys_prompt += `${messages[0].content}\n\n`;
|
||||
@ -388,53 +412,62 @@ export function convertGooglePrompt(messages, model, useSysPrompt = false, charN
|
||||
message.role = 'model';
|
||||
}
|
||||
|
||||
// Convert the content to an array of parts
|
||||
if (!Array.isArray(message.content)) {
|
||||
message.content = [{ type: 'text', text: String(message.content ?? '') }];
|
||||
}
|
||||
|
||||
// similar story as claude
|
||||
if (message.name) {
|
||||
if (userName && message.name === 'example_user') {
|
||||
message.name = userName;
|
||||
}
|
||||
if (charName && message.name === 'example_assistant') {
|
||||
message.name = charName;
|
||||
}
|
||||
|
||||
if (Array.isArray(message.content)) {
|
||||
if (!message.content[0].text.startsWith(`${message.name}: `)) {
|
||||
message.content[0].text = `${message.name}: ${message.content[0].text}`;
|
||||
message.content.forEach((part) => {
|
||||
if (part.type !== 'text') {
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
if (!message.content.startsWith(`${message.name}: `)) {
|
||||
message.content = `${message.name}: ${message.content}`;
|
||||
if (message.name === 'example_user') {
|
||||
if (!part.text.startsWith(`${names.userName}: `)) {
|
||||
part.text = `${names.userName}: ${part.text}`;
|
||||
}
|
||||
} else if (message.name === 'example_assistant') {
|
||||
if (!part.text.startsWith(`${names.charName}: `) && !names.startsFromGroupName(part.text)) {
|
||||
part.text = `${names.charName}: ${part.text}`;
|
||||
}
|
||||
} else {
|
||||
if (!part.text.startsWith(`${message.name}: `)) {
|
||||
part.text = `${message.name}: ${part.text}`;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
delete message.name;
|
||||
}
|
||||
|
||||
//create the prompt parts
|
||||
const parts = [];
|
||||
if (typeof message.content === 'string') {
|
||||
parts.push({ text: message.content });
|
||||
} else if (Array.isArray(message.content)) {
|
||||
message.content.forEach((part) => {
|
||||
if (part.type === 'text') {
|
||||
parts.push({ text: part.text });
|
||||
} else if (part.type === 'image_url' && isMultimodal) {
|
||||
const mimeType = part.image_url.url.split(';')[0].split(':')[1];
|
||||
const base64Data = part.image_url.url.split(',')[1];
|
||||
parts.push({
|
||||
inlineData: {
|
||||
mimeType: mimeType,
|
||||
data: base64Data,
|
||||
},
|
||||
});
|
||||
hasImage = true;
|
||||
}
|
||||
});
|
||||
}
|
||||
message.content.forEach((part) => {
|
||||
if (part.type === 'text') {
|
||||
parts.push({ text: part.text });
|
||||
} else if (part.type === 'image_url' && isMultimodal) {
|
||||
const mimeType = part.image_url.url.split(';')[0].split(':')[1];
|
||||
const base64Data = part.image_url.url.split(',')[1];
|
||||
parts.push({
|
||||
inlineData: {
|
||||
mimeType: mimeType,
|
||||
data: base64Data,
|
||||
},
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// merge consecutive messages with the same role
|
||||
if (index > 0 && message.role === contents[contents.length - 1].role) {
|
||||
contents[contents.length - 1].parts[0].text += '\n\n' + parts[0].text;
|
||||
parts.forEach((part) => {
|
||||
if (part.text) {
|
||||
contents[contents.length - 1].parts[0].text += '\n\n' + part.text;
|
||||
}
|
||||
if (part.inlineData) {
|
||||
contents[contents.length - 1].parts.push(part);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
contents.push({
|
||||
role: message.role,
|
||||
@ -449,10 +482,10 @@ export function convertGooglePrompt(messages, model, useSysPrompt = false, charN
|
||||
/**
|
||||
* Convert AI21 prompt. Classic: system message squash, user/assistant message merge.
|
||||
* @param {object[]} messages Array of messages
|
||||
* @param {string} charName Character name
|
||||
* @param {string} userName User name
|
||||
* @param {PromptNames} names Prompt names
|
||||
* @returns {object[]} Prompt for AI21
|
||||
*/
|
||||
export function convertAI21Messages(messages, charName = '', userName = '') {
|
||||
export function convertAI21Messages(messages, names) {
|
||||
if (!Array.isArray(messages)) {
|
||||
return [];
|
||||
}
|
||||
@ -465,14 +498,14 @@ export function convertAI21Messages(messages, charName = '', userName = '') {
|
||||
break;
|
||||
}
|
||||
// Append example names if not already done by the frontend (e.g. for group chats).
|
||||
if (userName && messages[i].name === 'example_user') {
|
||||
if (!messages[i].content.startsWith(`${userName}: `)) {
|
||||
messages[i].content = `${userName}: ${messages[i].content}`;
|
||||
if (names.userName && messages[i].name === 'example_user') {
|
||||
if (!messages[i].content.startsWith(`${names.userName}: `)) {
|
||||
messages[i].content = `${names.userName}: ${messages[i].content}`;
|
||||
}
|
||||
}
|
||||
if (charName && messages[i].name === 'example_assistant') {
|
||||
if (!messages[i].content.startsWith(`${charName}: `)) {
|
||||
messages[i].content = `${charName}: ${messages[i].content}`;
|
||||
if (names.charName && messages[i].name === 'example_assistant') {
|
||||
if (!messages[i].content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(messages[i].content)) {
|
||||
messages[i].content = `${names.charName}: ${messages[i].content}`;
|
||||
}
|
||||
}
|
||||
systemPrompt += `${messages[i].content}\n\n`;
|
||||
@ -521,10 +554,10 @@ export function convertAI21Messages(messages, charName = '', userName = '') {
|
||||
/**
|
||||
* Convert a prompt from the ChatML objects to the format used by MistralAI.
|
||||
* @param {object[]} messages Array of messages
|
||||
* @param {string} charName Character name
|
||||
* @param {string} userName User name
|
||||
* @param {PromptNames} names Prompt names
|
||||
* @returns {object[]} Prompt for MistralAI
|
||||
*/
|
||||
export function convertMistralMessages(messages, charName = '', userName = '') {
|
||||
export function convertMistralMessages(messages, names) {
|
||||
if (!Array.isArray(messages)) {
|
||||
return [];
|
||||
}
|
||||
@ -549,15 +582,15 @@ export function convertMistralMessages(messages, charName = '', userName = '') {
|
||||
msg.tool_call_id = sanitizeToolId(msg.tool_call_id);
|
||||
}
|
||||
if (msg.role === 'system' && msg.name === 'example_assistant') {
|
||||
if (charName && !msg.content.startsWith(`${charName}: `)) {
|
||||
msg.content = `${charName}: ${msg.content}`;
|
||||
if (names.charName && !msg.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(msg.content)) {
|
||||
msg.content = `${names.charName}: ${msg.content}`;
|
||||
}
|
||||
delete msg.name;
|
||||
}
|
||||
|
||||
if (msg.role === 'system' && msg.name === 'example_user') {
|
||||
if (userName && !msg.content.startsWith(`${userName}: `)) {
|
||||
msg.content = `${userName}: ${msg.content}`;
|
||||
if (names.userName && !msg.content.startsWith(`${names.userName}: `)) {
|
||||
msg.content = `${names.userName}: ${msg.content}`;
|
||||
}
|
||||
delete msg.name;
|
||||
}
|
||||
@ -603,12 +636,11 @@ export function convertMistralMessages(messages, charName = '', userName = '') {
|
||||
/**
|
||||
* Merge messages with the same consecutive role, removing names if they exist.
|
||||
* @param {any[]} messages Messages to merge
|
||||
* @param {string} charName Character name
|
||||
* @param {string} userName User name
|
||||
* @param {PromptNames} names Prompt names
|
||||
* @param {boolean} strict Enable strict mode: only allow one system message at the start, force user first message
|
||||
* @returns {any[]} Merged messages
|
||||
*/
|
||||
export function mergeMessages(messages, charName, userName, strict) {
|
||||
export function mergeMessages(messages, names, strict) {
|
||||
let mergedMessages = [];
|
||||
|
||||
/** @type {Map<string,object>} */
|
||||
@ -636,13 +668,13 @@ export function mergeMessages(messages, charName, userName, strict) {
|
||||
message.content = text;
|
||||
}
|
||||
if (message.role === 'system' && message.name === 'example_assistant') {
|
||||
if (charName && !message.content.startsWith(`${charName}: `)) {
|
||||
message.content = `${charName}: ${message.content}`;
|
||||
if (names.charName && !message.content.startsWith(`${names.charName}: `) && !names.startsFromGroupName(message.content)) {
|
||||
message.content = `${names.charName}: ${message.content}`;
|
||||
}
|
||||
}
|
||||
if (message.role === 'system' && message.name === 'example_user') {
|
||||
if (userName && !message.content.startsWith(`${userName}: `)) {
|
||||
message.content = `${userName}: ${message.content}`;
|
||||
if (names.userName && !message.content.startsWith(`${names.userName}: `)) {
|
||||
message.content = `${names.userName}: ${message.content}`;
|
||||
}
|
||||
}
|
||||
if (message.name && message.role !== 'system') {
|
||||
@ -716,7 +748,7 @@ export function mergeMessages(messages, charName, userName, strict) {
|
||||
mergedMessages.unshift({ role: 'user', content: PROMPT_PLACEHOLDER });
|
||||
}
|
||||
}
|
||||
return mergeMessages(mergedMessages, charName, userName, false);
|
||||
return mergeMessages(mergedMessages, names, false);
|
||||
}
|
||||
|
||||
return mergedMessages;
|
||||
|
Reference in New Issue
Block a user