Split Custom OAI prompt post-processing modes

This commit is contained in:
Cohee 2024-10-05 16:09:39 +03:00
parent 0637223bc2
commit 3b4a455ef8
5 changed files with 101 additions and 9 deletions

View File

@@ -121,6 +121,8 @@ extras:
speechToTextModel: Xenova/whisper-small
textToSpeechModel: Xenova/speecht5_tts
# -- OPENAI CONFIGURATION --
# A placeholder message to use in strict prompt post-processing mode when the prompt doesn't start with a user message
promptPlaceholder: "[Start a new chat]"
openai:
# Will send a random user ID to OpenAI completion API
randomizeUserId: false

View File

@@ -3102,7 +3102,8 @@
<h4 data-i18n="Prompt Post-Processing">Prompt Post-Processing</h4>
<select id="custom_prompt_post_processing" class="text_pole" title="Applies additional processing to the prompt before sending it to the API." data-i18n="[title]Applies additional processing to the prompt before sending it to the API.">
<option data-i18n="prompt_post_processing_none" value="">None</option>
<option value="claude">Claude</option>
<option value="merge">Merge consecutive roles</option>
<option value="strict">Strict (user first, alternating roles)</option>
</select>
</form>
<div id="01ai_form" data-source="01ai">

View File

@@ -199,7 +199,10 @@ const continue_postfix_types = {
// Values for the "Prompt Post-Processing" dropdown; sent to the server, which
// applies the matching message conversion before dispatching the prompt.
const custom_prompt_post_processing_types = {
NONE: '',
/** @deprecated Use MERGE instead. */
CLAUDE: 'claude',
// Merge consecutive messages that share the same role.
MERGE: 'merge',
// Merge + enforce a user-first, alternating-role sequence.
STRICT: 'strict',
};
const sensitiveFields = [
@@ -3043,6 +3046,10 @@ function loadOpenAISettings(data, settings) {
setNamesBehaviorControls();
setContinuePostfixControls();
if (oai_settings.custom_prompt_post_processing === custom_prompt_post_processing_types.CLAUDE) {
oai_settings.custom_prompt_post_processing = custom_prompt_post_processing_types.MERGE;
}
$('#chat_completion_source').val(oai_settings.chat_completion_source).trigger('change');
$('#oai_max_context_unlocked').prop('checked', oai_settings.max_context_unlocked);
$('#custom_prompt_post_processing').val(oai_settings.custom_prompt_post_processing);

View File

@@ -4,7 +4,7 @@ const fetch = require('node-fetch').default;
const { jsonParser } = require('../../express-common');
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertAI21Messages } = require('../../prompt-converters');
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertAI21Messages, mergeMessages } = require('../../prompt-converters');
const CohereStream = require('../../cohere-stream');
const { readSecret, SECRET_KEYS } = require('../secrets');
@@ -31,8 +31,11 @@ const API_AI21 = 'https://api.ai21.com/studio/v1';
*/
function postProcessPrompt(messages, type, charName, userName) {
    switch (type) {
        // 'claude' is a deprecated alias of 'merge', kept so saved settings keep working.
        case 'merge':
        case 'claude':
            return mergeMessages(messages, charName, userName, false);
        // 'strict' additionally forces a user-first, alternating-role sequence.
        case 'strict':
            return mergeMessages(messages, charName, userName, true);
        // Unknown or empty type: pass the messages through untouched.
        default:
            return messages;
    }
}
@@ -902,7 +905,7 @@ router.post('/generate', jsonParser, function (request, response) {
apiKey = readSecret(request.user.directories, SECRET_KEYS.PERPLEXITY);
headers = {};
bodyParams = {};
request.body.messages = postProcessPrompt(request.body.messages, 'claude', request.body.char_name, request.body.user_name);
request.body.messages = postProcessPrompt(request.body.messages, 'strict', request.body.char_name, request.body.user_name);
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.GROQ) {
apiUrl = API_GROQ;
apiKey = readSecret(request.user.directories, SECRET_KEYS.GROQ);

View File

@@ -1,6 +1,8 @@
require('./polyfill.js');
const { getConfigValue } = require('./util.js');
// Placeholder user message injected whenever a prompt would otherwise be empty
// or must start with a user turn; configurable via the `promptPlaceholder` config key.
// NOTE(review): this code fallback differs from the documented config default
// "[Start a new chat]" — it only applies when the key is absent; confirm intentional.
const PROMPT_PLACEHOLDER = getConfigValue('promptPlaceholder', 'Let\'s get started.');
/**
* Convert a prompt from the ChatML objects to the format used by Claude.
* Mainly deprecated. Only used for counting tokens.
@@ -122,7 +124,7 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
if (messages.length === 0 || (messages.length > 0 && messages[0].role !== 'user')) {
messages.unshift({
role: 'user',
content: humanMsgFix || '[Start a new chat]',
content: humanMsgFix || PROMPT_PLACEHOLDER,
});
}
}
@@ -260,7 +262,6 @@ function convertCohereMessages(messages, charName = '', userName = '') {
'user': 'USER',
'assistant': 'CHATBOT',
};
const placeholder = '[Start a new chat]';
let systemPrompt = '';
// Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
@@ -288,12 +289,12 @@ function convertCohereMessages(messages, charName = '', userName = '') {
if (messages.length === 0) {
messages.unshift({
role: 'user',
content: placeholder,
content: PROMPT_PLACEHOLDER,
});
}
const lastNonSystemMessageIndex = messages.findLastIndex(msg => msg.role === 'user' || msg.role === 'assistant');
const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || placeholder;
const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || PROMPT_PLACEHOLDER;
const chatHistory = messages.slice(0, lastNonSystemMessageIndex).map(msg => {
return {
@@ -469,7 +470,7 @@ function convertAI21Messages(messages, charName = '', userName = '') {
if (messages.length === 0) {
messages.unshift({
role: 'user',
content: '[Start a new chat]',
content: PROMPT_PLACEHOLDER,
});
}
@@ -553,6 +554,83 @@ function convertMistralMessages(messages, charName = '', userName = '') {
return messages;
}
/**
 * Merge messages with the same consecutive role, removing names if they exist.
 * Note: mutates the message objects of the input array in place (names stripped,
 * content prefixed, tool metadata deleted); the returned array shares those objects.
 * @param {any[]} messages Messages to merge
 * @param {string} charName Character name
 * @param {string} userName User name
 * @param {boolean} strict Enable strict mode: only allow one system message at the start, force user first message
 * @returns {any[]} Merged messages
 */
function mergeMessages(messages, charName, userName, strict) {
    let mergedMessages = [];

    // Pass 1: normalize each message in place.
    messages.forEach((message) => {
        if (!message.content) {
            message.content = '';
        }
        // Example dialogue arrives as system messages with special names;
        // inline the speaker name into the content so it survives the merge.
        if (message.role === 'system' && message.name === 'example_assistant') {
            if (charName && !message.content.startsWith(`${charName}: `)) {
                message.content = `${charName}: ${message.content}`;
            }
        }
        if (message.role === 'system' && message.name === 'example_user') {
            if (userName && !message.content.startsWith(`${userName}: `)) {
                message.content = `${userName}: ${message.content}`;
            }
        }
        // Non-system named messages also get their name inlined.
        if (message.name && message.role !== 'system') {
            if (!message.content.startsWith(`${message.name}: `)) {
                message.content = `${message.name}: ${message.content}`;
            }
        }
        // Tool results are coerced to user messages; tool metadata is dropped.
        if (message.role === 'tool') {
            message.role = 'user';
        }
        delete message.name;
        delete message.tool_calls;
        delete message.tool_call_id;
    });

    // Pass 2: squash consecutive messages that share a role.
    messages.forEach((message) => {
        if (mergedMessages.length > 0 && mergedMessages[mergedMessages.length - 1].role === message.role && message.content) {
            mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content;
        } else {
            mergedMessages.push(message);
        }
    });

    // Prevent erroring out if the messages array is empty.
    // FIX: the placeholder must go into the merged OUTPUT; the original pushed it
    // into the input array after both passes had already run, so it was lost.
    if (mergedMessages.length === 0) {
        mergedMessages.push({
            role: 'user',
            content: PROMPT_PLACEHOLDER,
        });
    }

    if (strict) {
        for (let i = 0; i < mergedMessages.length; i++) {
            // Only the first message may be a system message; demote the rest to user.
            if (i > 0 && mergedMessages[i].role === 'system') {
                mergedMessages[i].role = 'user';
            }
        }
        if (mergedMessages.length) {
            // Guarantee the first non-system message is a user message.
            if (mergedMessages[0].role === 'system' && (mergedMessages.length === 1 || mergedMessages[1].role !== 'user')) {
                mergedMessages.splice(1, 0, { role: 'user', content: PROMPT_PLACEHOLDER });
            }
            else if (mergedMessages[0].role !== 'system' && mergedMessages[0].role !== 'user') {
                mergedMessages.unshift({ role: 'user', content: PROMPT_PLACEHOLDER });
            }
        }
        // Demotions above may have created consecutive same-role messages; squash once more.
        return mergeMessages(mergedMessages, charName, userName, false);
    }

    return mergedMessages;
}
/**
* Convert a prompt from the ChatML objects to the format used by Text Completion API.
* @param {object[]} messages Array of messages
@@ -586,4 +664,5 @@ module.exports = {
convertCohereMessages,
convertMistralMessages,
convertAI21Messages,
mergeMessages,
};