Implement user alignment message

Cohee 2024-03-28 02:27:37 +02:00
parent d02c93e84f
commit 4b6a3054b1
1 changed file with 38 additions and 7 deletions


@@ -3264,6 +3264,8 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
    let chat2 = [];
    let continue_mag = '';
    const userMessageIndices = [];
    for (let i = coreChat.length - 1, j = 0; i >= 0; i--, j++) {
        if (main_api == 'openai') {
            chat2[i] = coreChat[j].mes;
@@ -3291,6 +3293,22 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
            chat2[i] = chat2[i].slice(0, chat2[i].lastIndexOf(coreChat[j].mes) + coreChat[j].mes.length);
            continue_mag = coreChat[j].mes;
        }
        if (coreChat[j].is_user) {
            userMessageIndices.push(i);
        }
    }
    let addUserAlignment = isInstruct && power_user.instruct.user_alignment_message;
    let userAlignmentMessage = '';
    if (addUserAlignment) {
        const alignmentMessage = {
            name: name1,
            mes: power_user.instruct.user_alignment_message,
            is_user: true,
        };
        userAlignmentMessage = formatMessageHistoryItem(alignmentMessage, isInstruct, false);
    }
    // Add persona description to prompt
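Note: chat2 is filled in reverse chat order (index 0 ends up holding the newest message), so userMessageIndices records which chat2 slots came from user messages for the cutoff check later in this commit. A minimal sketch of that index mapping, with made-up messages rather than the project's real chat objects:

// Sketch only: how coreChat positions (oldest-first) map onto chat2 indices (newest-first).
const coreChat = [
    { mes: 'hello', is_user: true },        // j = 0 -> i = 2
    { mes: 'hi there', is_user: false },    // j = 1 -> i = 1
    { mes: 'how are you?', is_user: true }, // j = 2 -> i = 0
];

const chat2 = [];
const userMessageIndices = [];

for (let i = coreChat.length - 1, j = 0; i >= 0; i--, j++) {
    chat2[i] = coreChat[j].mes;
    if (coreChat[j].is_user) {
        userMessageIndices.push(i);
    }
}

console.log(chat2);               // ['how are you?', 'hi there', 'hello']
console.log(userMessageIndices);  // [2, 0] -> chat2[0] is the newest user message

The alignment message itself is rendered once up front via formatMessageHistoryItem, so the later block can append the already-formatted string when needed.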
@@ -3349,6 +3367,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
            allAnchors,
            quiet_prompt,
            cyclePrompt,
            userAlignmentMessage,
        ].join('').replace(/\r/gm, '');
        return getTokenCount(encodeString, power_user.token_padding);
    }
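Including userAlignmentMessage in the concatenated estimate means the token budget below already reserves room for the alignment message, whether or not it ends up being appended. A rough sketch of the estimating idea, with a trivial stand-in tokenizer instead of the real getTokenCount:

// Sketch: sum tokens over every prompt part so the history loop knows its remaining budget.
function estimatePromptTokens(parts, countTokens, padding = 0) {
    const encodeString = parts.join('').replace(/\r/gm, '');
    return countTokens(encodeString) + padding;
}

const roughTokenizer = (text) => Math.ceil(text.length / 4); // hypothetical stand-in
console.log(estimatePromptTokens(['story string ', 'anchors ', 'alignment message '], roughTokenizer, 64));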
@@ -3367,16 +3386,19 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
    // Collect enough messages to fill the context
    let arrMes = [];
    let tokenCount = getMessagesTokenCount();
    for (let item of chat2) {
    let lastAddedIndex = -1;
    for (let i = 0; i < chat2.length; i++) {
        // not needed for OAI prompting
        if (main_api == 'openai') {
            break;
        }
        const item = chat2[i];
        tokenCount += getTokenCount(item.replace(/\r/gm, ''));
        chatString = item + chatString;
        if (tokenCount < this_max_context) {
            arrMes[arrMes.length] = item;
            lastAddedIndex = i;
        } else {
            break;
        }
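The rewritten loop still collects messages newest-first until the context is full, but now remembers lastAddedIndex, the chat2 index of the oldest message that still fit; checking that index against userMessageIndices tells the next block whether the collected history already ends on a user message. A simplified sketch of the cutoff, assuming a toy token counter and ignoring the parts of the budget counted elsewhere:

// Sketch: take messages newest-first until the budget runs out, remember where we stopped.
function collectMessages(chat2, maxTokens, countTokens) {
    const arrMes = [];
    let tokenCount = 0;
    let lastAddedIndex = -1;

    for (let i = 0; i < chat2.length; i++) {
        tokenCount += countTokens(chat2[i]);
        if (tokenCount >= maxTokens) {
            break;
        }
        arrMes.push(chat2[i]);
        lastAddedIndex = i;
    }

    return { arrMes, lastAddedIndex };
}

const fits = collectMessages(['newest', 'older', 'oldest'], 3, () => 1);
console.log(fits.lastAddedIndex); // 1 -> 'oldest' did not fit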
@@ -3385,8 +3407,21 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
        await delay(1);
    }
    const stoppedAtUser = userMessageIndices.includes(lastAddedIndex);
    if (addUserAlignment && !stoppedAtUser) {
        tokenCount += getTokenCount(userAlignmentMessage.replace(/\r/gm, ''));
        chatString = userAlignmentMessage + chatString;
        arrMes[arrMes.length] = userAlignmentMessage;
        // Injected indices shift by 1 for user alignment message at the beginning
        injectedIndices.forEach((value, index) => (injectedIndices[index] = value + 1));
        injectedIndices.push(0);
    }
    // Filter injections which don't fit in the context
    injectedIndices = injectedIndices.filter(value => value < arrMes.length);
    if (main_api !== 'openai') {
        setInContextMessages(arrMes.length, type);
        setInContextMessages(arrMes.length - injectedIndices.length, type);
    }
    // Estimate how many unpinned example messages fit in the context
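Per the in-code comment, the alignment message takes the position at the beginning of the history, so every previously computed injection index moves up by one and index 0 now refers to the new entry; injections that no longer fit are filtered out, and the in-context counter excludes injected entries. A small sketch of that bookkeeping with illustrative numbers only:

// Sketch: bookkeeping after adding one extra entry for the alignment message.
let injectedIndices = [1, 3];
const collectedCount = 5; // arrMes.length in the real code, illustrative here

injectedIndices = injectedIndices.map(value => value + 1); // [2, 4]
injectedIndices.push(0);                                   // [2, 4, 0]

// Injections that fall outside the collected history are dropped.
injectedIndices = injectedIndices.filter(value => value < collectedCount);

// Messages reported as "in context" exclude the injected entries.
console.log(collectedCount - injectedIndices.length); // 5 - 3 = 2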
@@ -3664,7 +3699,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
    };
    finalMesSend.forEach((item, i) => {
        item.injected = Array.isArray(injectedIndices) && injectedIndices.includes(i);
        item.injected = injectedIndices.includes(finalMesSend.length - i - 1);
    });
    let data = {
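The new check mirrors the position with finalMesSend.length - i - 1 before testing membership, which implies injectedIndices and finalMesSend are counted from opposite ends of the list. A tiny sketch of the mirrored lookup with made-up entries:

// Sketch: flag entries whose mirrored index appears in injectedIndices.
const injectedIndices = [0, 2];
const finalMesSend = [
    { message: 'a' },
    { message: 'b' },
    { message: 'c' },
    { message: 'd' },
];

finalMesSend.forEach((item, i) => {
    item.injected = injectedIndices.includes(finalMesSend.length - i - 1);
});

console.log(finalMesSend.map(item => item.injected)); // [false, true, false, true]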
@@ -4030,10 +4065,6 @@ function doChatInject(messages, isContinue) {
        }
    }
    for (let i = 0; i < injectedIndices.length; i++) {
        injectedIndices[i] = messages.length - injectedIndices[i] - 1;
    }
    messages.reverse();
    return injectedIndices;
}
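With the trailing remap removed, doChatInject returns the injection indices as computed before messages.reverse() restores the original order, leaving any re-orientation to the callers (the shift in the alignment block and the mirrored lookup above). The dropped loop was simply an index mirror; a reduced sketch:

// Sketch of the removed remapping: mirror each index across an array of the given length.
function mirrorIndices(indices, length) {
    return indices.map(value => length - value - 1);
}

console.log(mirrorIndices([0, 2], 5)); // [4, 2]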