From 4b6a3054b1a63e89ab6b2e0f4b1b095ecea0d196 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Thu, 28 Mar 2024 02:27:37 +0200 Subject: [PATCH] Implement user alignment message --- public/script.js | 45 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/public/script.js b/public/script.js index b07ca95be..634c3b5f4 100644 --- a/public/script.js +++ b/public/script.js @@ -3264,6 +3264,8 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu let chat2 = []; let continue_mag = ''; + const userMessageIndices = []; + for (let i = coreChat.length - 1, j = 0; i >= 0; i--, j++) { if (main_api == 'openai') { chat2[i] = coreChat[j].mes; @@ -3291,6 +3293,22 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu chat2[i] = chat2[i].slice(0, chat2[i].lastIndexOf(coreChat[j].mes) + coreChat[j].mes.length); continue_mag = coreChat[j].mes; } + + if (coreChat[j].is_user) { + userMessageIndices.push(i); + } + } + + let addUserAlignment = isInstruct && power_user.instruct.user_alignment_message; + let userAlignmentMessage = ''; + + if (addUserAlignment) { + const alignmentMessage = { + name: name1, + mes: power_user.instruct.user_alignment_message, + is_user: true, + }; + userAlignmentMessage = formatMessageHistoryItem(alignmentMessage, isInstruct, false); } // Add persona description to prompt @@ -3349,6 +3367,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu allAnchors, quiet_prompt, cyclePrompt, + userAlignmentMessage, ].join('').replace(/\r/gm, ''); return getTokenCount(encodeString, power_user.token_padding); } @@ -3367,16 +3386,19 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu // Collect enough messages to fill the context let arrMes = []; let tokenCount = getMessagesTokenCount(); - for (let item of chat2) { + let lastAddedIndex = -1; + for (let i = 0; i < chat2.length; i++) { // not needed for OAI prompting if (main_api == 'openai') { break; } + const item = chat2[i]; tokenCount += getTokenCount(item.replace(/\r/gm, '')); chatString = item + chatString; if (tokenCount < this_max_context) { arrMes[arrMes.length] = item; + lastAddedIndex = i; } else { break; } @@ -3385,8 +3407,21 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu await delay(1); } + const stoppedAtUser = userMessageIndices.includes(lastAddedIndex); + if (addUserAlignment && !stoppedAtUser) { + tokenCount += getTokenCount(userAlignmentMessage.replace(/\r/gm, '')); + chatString = userAlignmentMessage + chatString; + arrMes[arrMes.length] = userAlignmentMessage; + // Injected indices shift by 1 for user alignment message at the beginning + injectedIndices.forEach((value, index) => (injectedIndices[index] = value + 1)); + injectedIndices.push(0); + } + + // Filter injections which don't fit in the context + injectedIndices = injectedIndices.filter(value => value < arrMes.length); + if (main_api !== 'openai') { - setInContextMessages(arrMes.length, type); + setInContextMessages(arrMes.length - injectedIndices.length, type); } // Estimate how many unpinned example messages fit in the context @@ -3664,7 +3699,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu }; finalMesSend.forEach((item, i) => { - item.injected = Array.isArray(injectedIndices) && injectedIndices.includes(i); + item.injected = injectedIndices.includes(finalMesSend.length - i - 1); }); let data = { @@ -4030,10 +4065,6 @@ function doChatInject(messages, isContinue) { } } - for (let i = 0; i < injectedIndices.length; i++) { - injectedIndices[i] = messages.length - injectedIndices[i] - 1; - } - messages.reverse(); return injectedIndices; }