Ignore recurse buffer for min activation steps

This commit is contained in:
Cohee
2024-07-04 00:28:34 +03:00
parent 47b679202f
commit 35b7fc3186

View File

@@ -157,6 +157,11 @@ class WorldInfoBuffer {
*/ */
#recurseBuffer = []; #recurseBuffer = [];
/**
* @type {string[]} Array of strings added by prompt injections that are valid for the current scan
*/
#injectBuffer = [];
/** /**
* @type {number} The skew of the global scan depth. Used in "min activations" * @type {number} The skew of the global scan depth. Used in "min activations"
*/ */
@@ -206,9 +211,10 @@ class WorldInfoBuffer {
/** /**
* Gets all messages up to the given depth + recursion buffer. * Gets all messages up to the given depth + recursion buffer.
* @param {WIScanEntry} entry The entry that triggered the scan * @param {WIScanEntry} entry The entry that triggered the scan
* @param {number} scanState The state of the scan
* @returns {string} A slice of buffer until the given depth (inclusive) * @returns {string} A slice of buffer until the given depth (inclusive)
*/ */
get(entry) { get(entry, scanState) {
let depth = entry.scanDepth ?? this.getDepth(); let depth = entry.scanDepth ?? this.getDepth();
if (depth <= this.#startDepth) { if (depth <= this.#startDepth) {
return ''; return '';
@@ -226,7 +232,12 @@ class WorldInfoBuffer {
let result = this.#depthBuffer.slice(this.#startDepth, depth).join('\n'); let result = this.#depthBuffer.slice(this.#startDepth, depth).join('\n');
if (this.#recurseBuffer.length > 0) { if (this.#injectBuffer.length > 0) {
result += '\n' + this.#injectBuffer.join('\n');
}
// Min activations should not include the recursion buffer
if (this.#recurseBuffer.length > 0 && scanState !== scan_state.MIN_ACTIVATIONS) {
result += '\n' + this.#recurseBuffer.join('\n'); result += '\n' + this.#recurseBuffer.join('\n');
} }
@@ -280,6 +291,14 @@ class WorldInfoBuffer {
this.#recurseBuffer.push(message); this.#recurseBuffer.push(message);
} }
/**
* Adds an injection to the buffer.
* @param {string} message The injection to add
*/
addInject(message) {
this.#injectBuffer.push(message);
}
/** /**
* Increments skew and sets startDepth to previous depth. * Increments skew and sets startDepth to previous depth.
*/ */
@@ -315,10 +334,11 @@ class WorldInfoBuffer {
/** /**
* Gets the match score for the given entry. * Gets the match score for the given entry.
* @param {WIScanEntry} entry Entry to check * @param {WIScanEntry} entry Entry to check
* @param {number} scanState The state of the scan
* @returns {number} The number of key activations for the given entry * @returns {number} The number of key activations for the given entry
*/ */
getScore(entry) { getScore(entry, scanState) {
const bufferState = this.get(entry); const bufferState = this.get(entry, scanState);
let numberOfPrimaryKeys = 0; let numberOfPrimaryKeys = 0;
let numberOfSecondaryKeys = 0; let numberOfSecondaryKeys = 0;
let primaryScore = 0; let primaryScore = 0;
@@ -3525,7 +3545,7 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
if (context.extensionPrompts[key]?.scan) { if (context.extensionPrompts[key]?.scan) {
const prompt = getExtensionPromptByName(key); const prompt = getExtensionPromptByName(key);
if (prompt) { if (prompt) {
buffer.addRecurse(prompt); buffer.addInject(prompt);
} }
} }
} }
@@ -3636,7 +3656,7 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
primary: for (let key of entry.key) { primary: for (let key of entry.key) {
const substituted = substituteParams(key); const substituted = substituteParams(key);
const textToScan = buffer.get(entry); const textToScan = buffer.get(entry, scanState);
if (substituted && buffer.matchKeys(textToScan, substituted.trim(), entry)) { if (substituted && buffer.matchKeys(textToScan, substituted.trim(), entry)) {
console.debug(`WI UID ${entry.uid} found by primary match: ${substituted}.`); console.debug(`WI UID ${entry.uid} found by primary match: ${substituted}.`);
@@ -3706,7 +3726,7 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
const textToScanTokens = await getTokenCountAsync(allActivatedText); const textToScanTokens = await getTokenCountAsync(allActivatedText);
const probabilityChecksBefore = failedProbabilityChecks.size; const probabilityChecksBefore = failedProbabilityChecks.size;
filterByInclusionGroups(newEntries, allActivatedEntries, buffer); filterByInclusionGroups(newEntries, allActivatedEntries, buffer, scanState);
console.debug('-- PROBABILITY CHECKS BEGIN --'); console.debug('-- PROBABILITY CHECKS BEGIN --');
for (const entry of newEntries) { for (const entry of newEntries) {
@@ -3858,8 +3878,9 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
* @param {Record<string, WIScanEntry[]>} groups The groups to filter * @param {Record<string, WIScanEntry[]>} groups The groups to filter
* @param {WorldInfoBuffer} buffer The buffer to use for scoring * @param {WorldInfoBuffer} buffer The buffer to use for scoring
* @param {(entry: WIScanEntry) => void} removeEntry The function to remove an entry * @param {(entry: WIScanEntry) => void} removeEntry The function to remove an entry
* @param {number} scanState The current scan state
*/ */
function filterGroupsByScoring(groups, buffer, removeEntry) { function filterGroupsByScoring(groups, buffer, removeEntry, scanState) {
for (const [key, group] of Object.entries(groups)) { for (const [key, group] of Object.entries(groups)) {
// Group scoring is disabled both globally and for the group entries // Group scoring is disabled both globally and for the group entries
if (!world_info_use_group_scoring && !group.some(x => x.useGroupScoring)) { if (!world_info_use_group_scoring && !group.some(x => x.useGroupScoring)) {
@@ -3867,7 +3888,7 @@ function filterGroupsByScoring(groups, buffer, removeEntry) {
continue; continue;
} }
const scores = group.map(entry => buffer.getScore(entry)); const scores = group.map(entry => buffer.getScore(entry, scanState));
const maxScore = Math.max(...scores); const maxScore = Math.max(...scores);
console.debug(`Group '${key}' max score: ${maxScore}`); console.debug(`Group '${key}' max score: ${maxScore}`);
//console.table(group.map((entry, i) => ({ uid: entry.uid, key: JSON.stringify(entry.key), score: scores[i] }))); //console.table(group.map((entry, i) => ({ uid: entry.uid, key: JSON.stringify(entry.key), score: scores[i] })));
@@ -3895,8 +3916,9 @@ function filterGroupsByScoring(groups, buffer, removeEntry) {
* @param {object[]} newEntries Entries activated on current recursion level * @param {object[]} newEntries Entries activated on current recursion level
* @param {Set<object>} allActivatedEntries Set of all activated entries * @param {Set<object>} allActivatedEntries Set of all activated entries
* @param {WorldInfoBuffer} buffer The buffer to use for scanning * @param {WorldInfoBuffer} buffer The buffer to use for scanning
* @param {number} scanState The current scan state
*/ */
function filterByInclusionGroups(newEntries, allActivatedEntries, buffer) { function filterByInclusionGroups(newEntries, allActivatedEntries, buffer, scanState) {
console.debug('-- INCLUSION GROUP CHECKS BEGIN --'); console.debug('-- INCLUSION GROUP CHECKS BEGIN --');
const grouped = newEntries.filter(x => x.group).reduce((acc, item) => { const grouped = newEntries.filter(x => x.group).reduce((acc, item) => {
item.group.split(/,\s*/).filter(x => x).forEach(group => { item.group.split(/,\s*/).filter(x => x).forEach(group => {
@@ -3925,7 +3947,7 @@ function filterByInclusionGroups(newEntries, allActivatedEntries, buffer) {
} }
} }
filterGroupsByScoring(grouped, buffer, removeEntry); filterGroupsByScoring(grouped, buffer, removeEntry, scanState);
for (const [key, group] of Object.entries(grouped)) { for (const [key, group] of Object.entries(grouped)) {
console.debug(`Checking inclusion group '${key}' with ${group.length} entries`, group); console.debug(`Checking inclusion group '${key}' with ${group.length} entries`, group);