Pre-populate chat history with injections

This commit is contained in:
Cohee 2024-03-28 02:59:52 +02:00
parent 4b6a3054b1
commit 689af3151a
1 changed file with 49 additions and 15 deletions

View File

@ -3384,21 +3384,19 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
} }
// Collect enough messages to fill the context // Collect enough messages to fill the context
let arrMes = []; let arrMes = new Array(chat2.length);
let tokenCount = getMessagesTokenCount(); let tokenCount = getMessagesTokenCount();
let lastAddedIndex = -1; let lastAddedIndex = -1;
for (let i = 0; i < chat2.length; i++) {
// not needed for OAI prompting
if (main_api == 'openai') {
break;
}
const item = chat2[i]; // Pre-allocate all injections first.
// If it doesn't fit - user shot himself in the foot
for (const index of injectedIndices) {
const item = chat2[index];
tokenCount += getTokenCount(item.replace(/\r/gm, '')); tokenCount += getTokenCount(item.replace(/\r/gm, ''));
chatString = item + chatString; chatString = item + chatString;
if (tokenCount < this_max_context) { if (tokenCount < this_max_context) {
arrMes[arrMes.length] = item; arrMes[index] = item;
lastAddedIndex = i; lastAddedIndex = Math.max(lastAddedIndex, index);
} else { } else {
break; break;
} }
@ -3407,18 +3405,54 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
await delay(1); await delay(1);
} }
for (let i = 0; i < chat2.length; i++) {
// not needed for OAI prompting
if (main_api == 'openai') {
break;
}
// Skip already injected messages
if (arrMes[i] !== undefined) {
continue;
}
const item = chat2[i];
tokenCount += getTokenCount(item.replace(/\r/gm, ''));
chatString = item + chatString;
if (tokenCount < this_max_context) {
arrMes[i] = item;
lastAddedIndex = Math.max(lastAddedIndex, i);
} else {
break;
}
// Prevent UI thread lock on tokenization
await delay(1);
}
// Add user alignment message if last message is not a user message
const stoppedAtUser = userMessageIndices.includes(lastAddedIndex); const stoppedAtUser = userMessageIndices.includes(lastAddedIndex);
if (addUserAlignment && !stoppedAtUser) { if (addUserAlignment && !stoppedAtUser) {
tokenCount += getTokenCount(userAlignmentMessage.replace(/\r/gm, '')); tokenCount += getTokenCount(userAlignmentMessage.replace(/\r/gm, ''));
chatString = userAlignmentMessage + chatString; chatString = userAlignmentMessage + chatString;
arrMes[arrMes.length] = userAlignmentMessage; arrMes.push(userAlignmentMessage);
// Injected indices shift by 1 for user alignment message at the beginning injectedIndices.push(arrMes.length - 1);
injectedIndices.forEach((value, index) => (injectedIndices[index] = value + 1));
injectedIndices.push(0);
} }
// Filter injections which don't fit in the context // Unsparse the array. Adjust injected indices
injectedIndices = injectedIndices.filter(value => value < arrMes.length); const newArrMes = [];
const newInjectedIndices = [];
for (let i = 0; i < arrMes.length; i++) {
if (arrMes[i] !== undefined) {
newArrMes.push(arrMes[i]);
if (injectedIndices.includes(i)) {
newInjectedIndices.push(newArrMes.length - 1);
}
}
}
arrMes = newArrMes;
injectedIndices = newInjectedIndices;
if (main_api !== 'openai') { if (main_api !== 'openai') {
setInContextMessages(arrMes.length - injectedIndices.length, type); setInContextMessages(arrMes.length - injectedIndices.length, type);