Split Custom OAI prompt post-processing modes

This commit is contained in:
Cohee
2024-10-05 16:09:39 +03:00
parent 0637223bc2
commit 3b4a455ef8
5 changed files with 101 additions and 9 deletions

View File

@ -1,6 +1,8 @@
require('./polyfill.js');
const { getConfigValue } = require('./util.js');
const PROMPT_PLACEHOLDER = getConfigValue('promptPlaceholder', 'Let\'s get started.');
/**
* Convert a prompt from the ChatML objects to the format used by Claude.
* Mainly deprecated. Only used for counting tokens.
@ -122,7 +124,7 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
if (messages.length === 0 || (messages.length > 0 && messages[0].role !== 'user')) {
messages.unshift({
role: 'user',
content: humanMsgFix || '[Start a new chat]',
content: humanMsgFix || PROMPT_PLACEHOLDER,
});
}
}
@ -260,7 +262,6 @@ function convertCohereMessages(messages, charName = '', userName = '') {
'user': 'USER',
'assistant': 'CHATBOT',
};
const placeholder = '[Start a new chat]';
let systemPrompt = '';
// Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
@ -288,12 +289,12 @@ function convertCohereMessages(messages, charName = '', userName = '') {
if (messages.length === 0) {
messages.unshift({
role: 'user',
content: placeholder,
content: PROMPT_PLACEHOLDER,
});
}
const lastNonSystemMessageIndex = messages.findLastIndex(msg => msg.role === 'user' || msg.role === 'assistant');
const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || placeholder;
const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || PROMPT_PLACEHOLDER;
const chatHistory = messages.slice(0, lastNonSystemMessageIndex).map(msg => {
return {
@ -469,7 +470,7 @@ function convertAI21Messages(messages, charName = '', userName = '') {
if (messages.length === 0) {
messages.unshift({
role: 'user',
content: '[Start a new chat]',
content: PROMPT_PLACEHOLDER,
});
}
@ -553,6 +554,83 @@ function convertMistralMessages(messages, charName = '', userName = '') {
return messages;
}
/**
* Merge messages with the same consecutive role, removing names if they exist.
* @param {any[]} messages Messages to merge
* @param {string} charName Character name
* @param {string} userName User name
* @param {boolean} strict Enable strict mode: only allow one system message at the start, force user first message
* @returns {any[]} Merged messages
*/
function mergeMessages(messages, charName, userName, strict) {
let mergedMessages = [];
// Remove names from the messages
messages.forEach((message) => {
if (!message.content) {
message.content = '';
}
if (message.role === 'system' && message.name === 'example_assistant') {
if (charName && !message.content.startsWith(`${charName}: `)) {
message.content = `${charName}: ${message.content}`;
}
}
if (message.role === 'system' && message.name === 'example_user') {
if (userName && !message.content.startsWith(`${userName}: `)) {
message.content = `${userName}: ${message.content}`;
}
}
if (message.name && message.role !== 'system') {
if (!message.content.startsWith(`${message.name}: `)) {
message.content = `${message.name}: ${message.content}`;
}
}
if (message.role === 'tool') {
message.role = 'user';
}
delete message.name;
delete message.tool_calls;
delete message.tool_call_id;
});
// Squash consecutive messages with the same role
messages.forEach((message) => {
if (mergedMessages.length > 0 && mergedMessages[mergedMessages.length - 1].role === message.role && message.content) {
mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content;
} else {
mergedMessages.push(message);
}
});
// Prevent erroring out if the messages array is empty.
if (messages.length === 0) {
messages.unshift({
role: 'user',
content: PROMPT_PLACEHOLDER,
});
}
if (strict) {
for (let i = 0; i < mergedMessages.length; i++) {
// Force mid-prompt system messages to be user messages
if (i > 0 && mergedMessages[i].role === 'system') {
mergedMessages[i].role = 'user';
}
}
if (mergedMessages.length) {
if (mergedMessages[0].role === 'system' && (mergedMessages.length === 1 || mergedMessages[1].role !== 'user')) {
mergedMessages.splice(1, 0, { role: 'user', content: PROMPT_PLACEHOLDER });
}
else if (mergedMessages[0].role !== 'system' && mergedMessages[0].role !== 'user') {
mergedMessages.unshift({ role: 'user', content: PROMPT_PLACEHOLDER });
}
}
return mergeMessages(mergedMessages, charName, userName, false);
}
return mergedMessages;
}
/**
* Convert a prompt from the ChatML objects to the format used by Text Completion API.
* @param {object[]} messages Array of messages
@ -586,4 +664,5 @@ module.exports = {
convertCohereMessages,
convertMistralMessages,
convertAI21Messages,
mergeMessages,
};