Split Custom OAI prompt post-processing modes

Author: Cohee
Date: 2024-10-05 16:09:39 +03:00
Parent: 0637223bc2
Commit: 3b4a455ef8

5 changed files with 101 additions and 9 deletions

File: default/config.yaml

@@ -121,6 +121,8 @@ extras:
   speechToTextModel: Xenova/whisper-small
   textToSpeechModel: Xenova/speecht5_tts
 # -- OPENAI CONFIGURATION --
+# A placeholder message to use in strict prompt post-processing mode when the prompt doesn't start with a user message
+promptPlaceholder: "[Start a new chat]"
 openai:
   # Will send a random user ID to OpenAI completion API
   randomizeUserId: false

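For context: `PROMPT_PLACEHOLDER` in `src/prompt-converters.js` (the last file below) reads this key once at startup. Note that the two defaults differ: the shipped config keeps the historical "[Start a new chat]" text, while the code-level fallback used when the key is absent is "Let's get started.". A minimal sketch of the resolution order, assuming `getConfigValue(key, fallback)` returns the fallback for missing keys:

```js
const { getConfigValue } = require('./util.js');

// The YAML value wins when present; the second argument is only used
// when config.yaml has no promptPlaceholder key at all.
const PROMPT_PLACEHOLDER = getConfigValue('promptPlaceholder', 'Let\'s get started.');

console.log(PROMPT_PLACEHOLDER); // "[Start a new chat]" with the default config.yaml
```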
File: public/index.html

@@ -3102,7 +3102,8 @@
 <h4 data-i18n="Prompt Post-Processing">Prompt Post-Processing</h4>
 <select id="custom_prompt_post_processing" class="text_pole" title="Applies additional processing to the prompt before sending it to the API." data-i18n="[title]Applies additional processing to the prompt before sending it to the API.">
     <option data-i18n="prompt_post_processing_none" value="">None</option>
-    <option value="claude">Claude</option>
+    <option value="merge">Merge consecutive roles</option>
+    <option value="strict">Strict (user first, alternating roles)</option>
 </select>
 </form>
 <div id="01ai_form" data-source="01ai">

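The legacy `Claude` option is gone from the dropdown; settings saved with the old `claude` value are migrated on load (see `public/scripts/openai.js` below). A hypothetical sketch of how such a select is typically wired to settings, not the literal handler from this commit:

```js
// Hypothetical wiring sketch; assumes the jQuery + saveSettingsDebounced
// pattern used elsewhere in public/scripts/openai.js.
$('#custom_prompt_post_processing').on('change', function () {
    oai_settings.custom_prompt_post_processing = String($(this).val());
    saveSettingsDebounced();
});
```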
File: public/scripts/openai.js

@@ -199,7 +199,10 @@ const continue_postfix_types = {
 const custom_prompt_post_processing_types = {
     NONE: '',
+    /** @deprecated Use MERGE instead. */
     CLAUDE: 'claude',
+    MERGE: 'merge',
+    STRICT: 'strict',
 };
 
 const sensitiveFields = [
@@ -3043,6 +3046,10 @@ function loadOpenAISettings(data, settings) {
     setNamesBehaviorControls();
     setContinuePostfixControls();
 
+    if (oai_settings.custom_prompt_post_processing === custom_prompt_post_processing_types.CLAUDE) {
+        oai_settings.custom_prompt_post_processing = custom_prompt_post_processing_types.MERGE;
+    }
+
     $('#chat_completion_source').val(oai_settings.chat_completion_source).trigger('change');
     $('#oai_max_context_unlocked').prop('checked', oai_settings.max_context_unlocked);
     $('#custom_prompt_post_processing').val(oai_settings.custom_prompt_post_processing);

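The `CLAUDE` enum member is kept only so old saved settings still parse; `loadOpenAISettings` rewrites it to `MERGE` before the dropdown is populated. The migration reduces to a one-liner, shown here as an illustrative sketch:

```js
// Equivalent of the migration applied in loadOpenAISettings:
function migratePostProcessingType(type) {
    return type === 'claude' ? 'merge' : type;
}

migratePostProcessingType('claude'); // => 'merge'
migratePostProcessingType('strict'); // => 'strict' (unchanged)
```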
File: src/endpoints/backends/chat-completions.js

@@ -4,7 +4,7 @@ const fetch = require('node-fetch').default;
 const { jsonParser } = require('../../express-common');
 const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
 const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
-const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertAI21Messages } = require('../../prompt-converters');
+const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertAI21Messages, mergeMessages } = require('../../prompt-converters');
 
 const CohereStream = require('../../cohere-stream');
 const { readSecret, SECRET_KEYS } = require('../secrets');
 
@@ -31,8 +31,11 @@ const API_AI21 = 'https://api.ai21.com/studio/v1';
  */
 function postProcessPrompt(messages, type, charName, userName) {
     switch (type) {
+        case 'merge':
         case 'claude':
-            return convertClaudeMessages(messages, '', false, '', charName, userName).messages;
+            return mergeMessages(messages, charName, userName, false);
+        case 'strict':
+            return mergeMessages(messages, charName, userName, true);
         default:
             return messages;
     }
 
@@ -902,7 +905,7 @@ router.post('/generate', jsonParser, function (request, response) {
         apiKey = readSecret(request.user.directories, SECRET_KEYS.PERPLEXITY);
         headers = {};
         bodyParams = {};
-        request.body.messages = postProcessPrompt(request.body.messages, 'claude', request.body.char_name, request.body.user_name);
+        request.body.messages = postProcessPrompt(request.body.messages, 'strict', request.body.char_name, request.body.user_name);
     } else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.GROQ) {
         apiUrl = API_GROQ;
         apiKey = readSecret(request.user.directories, SECRET_KEYS.GROQ);

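`merge` and the legacy `claude` type now share the non-strict path, while `strict` additionally enforces a user-first, role-alternating layout. Perplexity is switched to `strict`, presumably because its API rejects prompts that do not alternate user/assistant roles after the system prompt. An illustrative call with hypothetical message content (cloned, since the underlying `mergeMessages` mutates its input):

```js
const messages = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'assistant', content: 'Hello!' },
];

// Non-strict: only squashes consecutive same-role messages.
postProcessPrompt(structuredClone(messages), 'merge', 'Bot', 'User');

// Strict: also guarantees the first non-system message is a user turn,
// inserting the configured promptPlaceholder when necessary.
postProcessPrompt(structuredClone(messages), 'strict', 'Bot', 'User');
```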
File: src/prompt-converters.js

@@ -1,6 +1,8 @@
 require('./polyfill.js');
 const { getConfigValue } = require('./util.js');
 
+const PROMPT_PLACEHOLDER = getConfigValue('promptPlaceholder', 'Let\'s get started.');
+
 /**
  * Convert a prompt from the ChatML objects to the format used by Claude.
  * Mainly deprecated. Only used for counting tokens.
 
@@ -122,7 +124,7 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFix, charName, userName) {
         if (messages.length === 0 || (messages.length > 0 && messages[0].role !== 'user')) {
             messages.unshift({
                 role: 'user',
-                content: humanMsgFix || '[Start a new chat]',
+                content: humanMsgFix || PROMPT_PLACEHOLDER,
             });
         }
     }
 
@@ -260,7 +262,6 @@ function convertCohereMessages(messages, charName = '', userName = '') {
         'user': 'USER',
         'assistant': 'CHATBOT',
     };
-    const placeholder = '[Start a new chat]';
 
     let systemPrompt = '';
     // Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
 
@@ -288,12 +289,12 @@ function convertCohereMessages(messages, charName = '', userName = '') {
     if (messages.length === 0) {
         messages.unshift({
             role: 'user',
-            content: placeholder,
+            content: PROMPT_PLACEHOLDER,
         });
     }
 
     const lastNonSystemMessageIndex = messages.findLastIndex(msg => msg.role === 'user' || msg.role === 'assistant');
-    const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || placeholder;
+    const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || PROMPT_PLACEHOLDER;
 
     const chatHistory = messages.slice(0, lastNonSystemMessageIndex).map(msg => {
         return {
 
@@ -469,7 +470,7 @@ function convertAI21Messages(messages, charName = '', userName = '') {
     if (messages.length === 0) {
         messages.unshift({
             role: 'user',
-            content: '[Start a new chat]',
+            content: PROMPT_PLACEHOLDER,
         });
     }
@@ -553,6 +554,83 @@ function convertMistralMessages(messages, charName = '', userName = '') {
     return messages;
 }
 
+/**
+ * Merge messages with the same consecutive role, removing names if they exist.
+ * @param {any[]} messages Messages to merge
+ * @param {string} charName Character name
+ * @param {string} userName User name
+ * @param {boolean} strict Enable strict mode: only allow one system message at the start, force user first message
+ * @returns {any[]} Merged messages
+ */
+function mergeMessages(messages, charName, userName, strict) {
+    let mergedMessages = [];
+
+    // Remove names from the messages
+    messages.forEach((message) => {
+        if (!message.content) {
+            message.content = '';
+        }
+        if (message.role === 'system' && message.name === 'example_assistant') {
+            if (charName && !message.content.startsWith(`${charName}: `)) {
+                message.content = `${charName}: ${message.content}`;
+            }
+        }
+        if (message.role === 'system' && message.name === 'example_user') {
+            if (userName && !message.content.startsWith(`${userName}: `)) {
+                message.content = `${userName}: ${message.content}`;
+            }
+        }
+        if (message.name && message.role !== 'system') {
+            if (!message.content.startsWith(`${message.name}: `)) {
+                message.content = `${message.name}: ${message.content}`;
+            }
+        }
+        if (message.role === 'tool') {
+            message.role = 'user';
+        }
+        delete message.name;
+        delete message.tool_calls;
+        delete message.tool_call_id;
+    });
+
+    // Squash consecutive messages with the same role
+    messages.forEach((message) => {
+        if (mergedMessages.length > 0 && mergedMessages[mergedMessages.length - 1].role === message.role && message.content) {
+            mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content;
+        } else {
+            mergedMessages.push(message);
+        }
+    });
+
+    // Prevent erroring out if the messages array is empty.
+    if (messages.length === 0) {
+        messages.unshift({
+            role: 'user',
+            content: PROMPT_PLACEHOLDER,
+        });
+    }
+
+    if (strict) {
+        for (let i = 0; i < mergedMessages.length; i++) {
+            // Force mid-prompt system messages to be user messages
+            if (i > 0 && mergedMessages[i].role === 'system') {
+                mergedMessages[i].role = 'user';
+            }
+        }
+        if (mergedMessages.length) {
+            if (mergedMessages[0].role === 'system' && (mergedMessages.length === 1 || mergedMessages[1].role !== 'user')) {
+                mergedMessages.splice(1, 0, { role: 'user', content: PROMPT_PLACEHOLDER });
+            }
+            else if (mergedMessages[0].role !== 'system' && mergedMessages[0].role !== 'user') {
+                mergedMessages.unshift({ role: 'user', content: PROMPT_PLACEHOLDER });
+            }
+        }
+        return mergeMessages(mergedMessages, charName, userName, false);
+    }
+
+    return mergedMessages;
+}
+
 /**
  * Convert a prompt from the ChatML objects to the format used by Text Completion API.
  * @param {object[]} messages Array of messages
 
@@ -586,4 +664,5 @@ module.exports = {
     convertCohereMessages,
     convertMistralMessages,
     convertAI21Messages,
+    mergeMessages,
 };
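A worked example of the two modes, assuming the shipped promptPlaceholder value of "[Start a new chat]" and hypothetical names and message content (clones, because the function mutates its input):

```js
const input = [
    { role: 'system', content: 'Main prompt.' },
    { role: 'system', name: 'example_user', content: 'Hello!' },
    { role: 'system', name: 'example_assistant', content: 'Hey.' },
    { role: 'assistant', content: 'How can I help?' },
];

mergeMessages(structuredClone(input), 'Bot', 'User', false);
// => [
//   { role: 'system', content: 'Main prompt.\n\nUser: Hello!\n\nBot: Hey.' },
//   { role: 'assistant', content: 'How can I help?' },
// ]

mergeMessages(structuredClone(input), 'Bot', 'User', true);
// => [
//   { role: 'system', content: 'Main prompt.\n\nUser: Hello!\n\nBot: Hey.' },
//   { role: 'user', content: '[Start a new chat]' },  // placeholder keeps the prompt user-first
//   { role: 'assistant', content: 'How can I help?' },
// ]
```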