Implement user alignment message

This commit is contained in:
Cohee 2024-03-28 02:27:37 +02:00
parent d02c93e84f
commit 4b6a3054b1
1 changed files with 38 additions and 7 deletions

View File

@ -3264,6 +3264,8 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
let chat2 = []; let chat2 = [];
let continue_mag = ''; let continue_mag = '';
const userMessageIndices = [];
for (let i = coreChat.length - 1, j = 0; i >= 0; i--, j++) { for (let i = coreChat.length - 1, j = 0; i >= 0; i--, j++) {
if (main_api == 'openai') { if (main_api == 'openai') {
chat2[i] = coreChat[j].mes; chat2[i] = coreChat[j].mes;
@ -3291,6 +3293,22 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
chat2[i] = chat2[i].slice(0, chat2[i].lastIndexOf(coreChat[j].mes) + coreChat[j].mes.length); chat2[i] = chat2[i].slice(0, chat2[i].lastIndexOf(coreChat[j].mes) + coreChat[j].mes.length);
continue_mag = coreChat[j].mes; continue_mag = coreChat[j].mes;
} }
if (coreChat[j].is_user) {
userMessageIndices.push(i);
}
}
let addUserAlignment = isInstruct && power_user.instruct.user_alignment_message;
let userAlignmentMessage = '';
if (addUserAlignment) {
const alignmentMessage = {
name: name1,
mes: power_user.instruct.user_alignment_message,
is_user: true,
};
userAlignmentMessage = formatMessageHistoryItem(alignmentMessage, isInstruct, false);
} }
// Add persona description to prompt // Add persona description to prompt
@ -3349,6 +3367,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
allAnchors, allAnchors,
quiet_prompt, quiet_prompt,
cyclePrompt, cyclePrompt,
userAlignmentMessage,
].join('').replace(/\r/gm, ''); ].join('').replace(/\r/gm, '');
return getTokenCount(encodeString, power_user.token_padding); return getTokenCount(encodeString, power_user.token_padding);
} }
@ -3367,16 +3386,19 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
// Collect enough messages to fill the context // Collect enough messages to fill the context
let arrMes = []; let arrMes = [];
let tokenCount = getMessagesTokenCount(); let tokenCount = getMessagesTokenCount();
for (let item of chat2) { let lastAddedIndex = -1;
for (let i = 0; i < chat2.length; i++) {
// not needed for OAI prompting // not needed for OAI prompting
if (main_api == 'openai') { if (main_api == 'openai') {
break; break;
} }
const item = chat2[i];
tokenCount += getTokenCount(item.replace(/\r/gm, '')); tokenCount += getTokenCount(item.replace(/\r/gm, ''));
chatString = item + chatString; chatString = item + chatString;
if (tokenCount < this_max_context) { if (tokenCount < this_max_context) {
arrMes[arrMes.length] = item; arrMes[arrMes.length] = item;
lastAddedIndex = i;
} else { } else {
break; break;
} }
@ -3385,8 +3407,21 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
await delay(1); await delay(1);
} }
const stoppedAtUser = userMessageIndices.includes(lastAddedIndex);
if (addUserAlignment && !stoppedAtUser) {
tokenCount += getTokenCount(userAlignmentMessage.replace(/\r/gm, ''));
chatString = userAlignmentMessage + chatString;
arrMes[arrMes.length] = userAlignmentMessage;
// Injected indices shift by 1 for user alignment message at the beginning
injectedIndices.forEach((value, index) => (injectedIndices[index] = value + 1));
injectedIndices.push(0);
}
// Filter injections which don't fit in the context
injectedIndices = injectedIndices.filter(value => value < arrMes.length);
if (main_api !== 'openai') { if (main_api !== 'openai') {
setInContextMessages(arrMes.length, type); setInContextMessages(arrMes.length - injectedIndices.length, type);
} }
// Estimate how many unpinned example messages fit in the context // Estimate how many unpinned example messages fit in the context
@ -3664,7 +3699,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
}; };
finalMesSend.forEach((item, i) => { finalMesSend.forEach((item, i) => {
item.injected = Array.isArray(injectedIndices) && injectedIndices.includes(i); item.injected = injectedIndices.includes(finalMesSend.length - i - 1);
}); });
let data = { let data = {
@ -4030,10 +4065,6 @@ function doChatInject(messages, isContinue) {
} }
} }
for (let i = 0; i < injectedIndices.length; i++) {
injectedIndices[i] = messages.length - injectedIndices[i] - 1;
}
messages.reverse(); messages.reverse();
return injectedIndices; return injectedIndices;
} }