Implement user alignment message
This commit is contained in:
parent
d02c93e84f
commit
4b6a3054b1
|
@ -3264,6 +3264,8 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
|
||||||
|
|
||||||
let chat2 = [];
|
let chat2 = [];
|
||||||
let continue_mag = '';
|
let continue_mag = '';
|
||||||
|
const userMessageIndices = [];
|
||||||
|
|
||||||
for (let i = coreChat.length - 1, j = 0; i >= 0; i--, j++) {
|
for (let i = coreChat.length - 1, j = 0; i >= 0; i--, j++) {
|
||||||
if (main_api == 'openai') {
|
if (main_api == 'openai') {
|
||||||
chat2[i] = coreChat[j].mes;
|
chat2[i] = coreChat[j].mes;
|
||||||
|
@ -3291,6 +3293,22 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
|
||||||
chat2[i] = chat2[i].slice(0, chat2[i].lastIndexOf(coreChat[j].mes) + coreChat[j].mes.length);
|
chat2[i] = chat2[i].slice(0, chat2[i].lastIndexOf(coreChat[j].mes) + coreChat[j].mes.length);
|
||||||
continue_mag = coreChat[j].mes;
|
continue_mag = coreChat[j].mes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (coreChat[j].is_user) {
|
||||||
|
userMessageIndices.push(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let addUserAlignment = isInstruct && power_user.instruct.user_alignment_message;
|
||||||
|
let userAlignmentMessage = '';
|
||||||
|
|
||||||
|
if (addUserAlignment) {
|
||||||
|
const alignmentMessage = {
|
||||||
|
name: name1,
|
||||||
|
mes: power_user.instruct.user_alignment_message,
|
||||||
|
is_user: true,
|
||||||
|
};
|
||||||
|
userAlignmentMessage = formatMessageHistoryItem(alignmentMessage, isInstruct, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add persona description to prompt
|
// Add persona description to prompt
|
||||||
|
@ -3349,6 +3367,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
|
||||||
allAnchors,
|
allAnchors,
|
||||||
quiet_prompt,
|
quiet_prompt,
|
||||||
cyclePrompt,
|
cyclePrompt,
|
||||||
|
userAlignmentMessage,
|
||||||
].join('').replace(/\r/gm, '');
|
].join('').replace(/\r/gm, '');
|
||||||
return getTokenCount(encodeString, power_user.token_padding);
|
return getTokenCount(encodeString, power_user.token_padding);
|
||||||
}
|
}
|
||||||
|
@ -3367,16 +3386,19 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
|
||||||
// Collect enough messages to fill the context
|
// Collect enough messages to fill the context
|
||||||
let arrMes = [];
|
let arrMes = [];
|
||||||
let tokenCount = getMessagesTokenCount();
|
let tokenCount = getMessagesTokenCount();
|
||||||
for (let item of chat2) {
|
let lastAddedIndex = -1;
|
||||||
|
for (let i = 0; i < chat2.length; i++) {
|
||||||
// not needed for OAI prompting
|
// not needed for OAI prompting
|
||||||
if (main_api == 'openai') {
|
if (main_api == 'openai') {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const item = chat2[i];
|
||||||
tokenCount += getTokenCount(item.replace(/\r/gm, ''));
|
tokenCount += getTokenCount(item.replace(/\r/gm, ''));
|
||||||
chatString = item + chatString;
|
chatString = item + chatString;
|
||||||
if (tokenCount < this_max_context) {
|
if (tokenCount < this_max_context) {
|
||||||
arrMes[arrMes.length] = item;
|
arrMes[arrMes.length] = item;
|
||||||
|
lastAddedIndex = i;
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -3385,8 +3407,21 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
|
||||||
await delay(1);
|
await delay(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const stoppedAtUser = userMessageIndices.includes(lastAddedIndex);
|
||||||
|
if (addUserAlignment && !stoppedAtUser) {
|
||||||
|
tokenCount += getTokenCount(userAlignmentMessage.replace(/\r/gm, ''));
|
||||||
|
chatString = userAlignmentMessage + chatString;
|
||||||
|
arrMes[arrMes.length] = userAlignmentMessage;
|
||||||
|
// Injected indices shift by 1 for user alignment message at the beginning
|
||||||
|
injectedIndices.forEach((value, index) => (injectedIndices[index] = value + 1));
|
||||||
|
injectedIndices.push(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter injections which don't fit in the context
|
||||||
|
injectedIndices = injectedIndices.filter(value => value < arrMes.length);
|
||||||
|
|
||||||
if (main_api !== 'openai') {
|
if (main_api !== 'openai') {
|
||||||
setInContextMessages(arrMes.length, type);
|
setInContextMessages(arrMes.length - injectedIndices.length, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Estimate how many unpinned example messages fit in the context
|
// Estimate how many unpinned example messages fit in the context
|
||||||
|
@ -3664,7 +3699,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
|
||||||
};
|
};
|
||||||
|
|
||||||
finalMesSend.forEach((item, i) => {
|
finalMesSend.forEach((item, i) => {
|
||||||
item.injected = Array.isArray(injectedIndices) && injectedIndices.includes(i);
|
item.injected = injectedIndices.includes(finalMesSend.length - i - 1);
|
||||||
});
|
});
|
||||||
|
|
||||||
let data = {
|
let data = {
|
||||||
|
@ -4030,10 +4065,6 @@ function doChatInject(messages, isContinue) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (let i = 0; i < injectedIndices.length; i++) {
|
|
||||||
injectedIndices[i] = messages.length - injectedIndices[i] - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
messages.reverse();
|
messages.reverse();
|
||||||
return injectedIndices;
|
return injectedIndices;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue