diff --git a/public/scripts/world-info.js b/public/scripts/world-info.js index 2150fa4ab..b38ae1fa4 100644 --- a/public/scripts/world-info.js +++ b/public/scripts/world-info.js @@ -50,6 +50,28 @@ const world_info_logic = { AND_ALL: 3, }; +/** + * @enum {number} Possible states of the WI evaluation + */ +const scan_state = { + /** + * The scan will be stopped. + */ + NONE: 0, + /** + * Initial state. + */ + INITIAL: 1, + /** + * The scan is triggered by a recursion step. + */ + RECURSION: 2, + /** + * The scan is triggered by a min activations depth skew. + */ + MIN_ACTIVATIONS: 2, +}; + const WI_ENTRY_EDIT_TEMPLATE = $('#entry_edit_template .world_entry'); let world_info = {}; @@ -135,6 +157,11 @@ class WorldInfoBuffer { */ #recurseBuffer = []; + /** + * @type {string[]} Array of strings added by prompt injections that are valid for the current scan + */ + #injectBuffer = []; + /** * @type {number} The skew of the global scan depth. Used in "min activations" */ @@ -184,9 +211,10 @@ class WorldInfoBuffer { /** * Gets all messages up to the given depth + recursion buffer. * @param {WIScanEntry} entry The entry that triggered the scan + * @param {number} scanState The state of the scan * @returns {string} A slice of buffer until the given depth (inclusive) */ - get(entry) { + get(entry, scanState) { let depth = entry.scanDepth ?? this.getDepth(); if (depth <= this.#startDepth) { return ''; @@ -204,7 +232,12 @@ class WorldInfoBuffer { let result = this.#depthBuffer.slice(this.#startDepth, depth).join('\n'); - if (this.#recurseBuffer.length > 0) { + if (this.#injectBuffer.length > 0) { + result += '\n' + this.#injectBuffer.join('\n'); + } + + // Min activations should not include the recursion buffer + if (this.#recurseBuffer.length > 0 && scanState !== scan_state.MIN_ACTIVATIONS) { result += '\n' + this.#recurseBuffer.join('\n'); } @@ -258,6 +291,14 @@ class WorldInfoBuffer { this.#recurseBuffer.push(message); } + /** + * Adds an injection to the buffer. + * @param {string} message The injection to add + */ + addInject(message) { + this.#injectBuffer.push(message); + } + /** * Increments skew and sets startDepth to previous depth. */ @@ -293,10 +334,11 @@ class WorldInfoBuffer { /** * Gets the match score for the given entry. * @param {WIScanEntry} entry Entry to check + * @param {number} scanState The state of the scan * @returns {number} The number of key activations for the given entry */ - getScore(entry) { - const bufferState = this.get(entry); + getScore(entry, scanState) { + const bufferState = this.get(entry, scanState); let numberOfPrimaryKeys = 0; let numberOfSecondaryKeys = 0; let primaryScore = 0; @@ -3503,12 +3545,12 @@ async function checkWorldInfo(chat, maxContext, isDryRun) { if (context.extensionPrompts[key]?.scan) { const prompt = getExtensionPromptByName(key); if (prompt) { - buffer.addRecurse(prompt); + buffer.addInject(prompt); } } } - let needsToScan = true; + let scanState = scan_state.INITIAL; let token_budget_overflowed = false; let count = 0; let allActivatedEntries = new Set(); @@ -3532,8 +3574,9 @@ async function checkWorldInfo(chat, maxContext, isDryRun) { return { worldInfoBefore: '', worldInfoAfter: '', WIDepthEntries: [], EMEntries: [], allActivatedEntries: new Set() }; } - while (needsToScan) { - // Track how many times the loop has run + while (scanState) { + // Track how many times the loop has run. May be useful for debugging. + // eslint-disable-next-line no-unused-vars count++; let activatedNow = new Set(); @@ -3587,7 +3630,18 @@ async function checkWorldInfo(chat, maxContext, isDryRun) { continue; } - if (allActivatedEntries.has(entry) || entry.disable == true || (count > 1 && world_info_recursive && entry.excludeRecursion) || (count == 1 && entry.delayUntilRecursion)) { + if (allActivatedEntries.has(entry) || entry.disable == true) { + continue; + } + + // Only use checks for recursion flags if the scan step was activated by recursion + if (scanState !== scan_state.RECURSION && entry.delayUntilRecursion) { + console.debug(`WI entry ${entry.uid} suppressed by delay until recursion`, entry); + continue; + } + + if (scanState === scan_state.RECURSION && world_info_recursive && entry.excludeRecursion) { + console.debug(`WI entry ${entry.uid} suppressed by exclude recursion`, entry); continue; } @@ -3602,7 +3656,7 @@ async function checkWorldInfo(chat, maxContext, isDryRun) { primary: for (let key of entry.key) { const substituted = substituteParams(key); - const textToScan = buffer.get(entry); + const textToScan = buffer.get(entry, scanState); if (substituted && buffer.matchKeys(textToScan, substituted.trim(), entry)) { console.debug(`WI UID ${entry.uid} found by primary match: ${substituted}.`); @@ -3665,14 +3719,14 @@ async function checkWorldInfo(chat, maxContext, isDryRun) { } } - needsToScan = world_info_recursive && activatedNow.size > 0; + scanState = world_info_recursive && activatedNow.size > 0 ? scan_state.RECURSION : scan_state.NONE; const newEntries = [...activatedNow] .sort((a, b) => sortedEntries.indexOf(a) - sortedEntries.indexOf(b)); let newContent = ''; const textToScanTokens = await getTokenCountAsync(allActivatedText); const probabilityChecksBefore = failedProbabilityChecks.size; - filterByInclusionGroups(newEntries, allActivatedEntries, buffer); + filterByInclusionGroups(newEntries, allActivatedEntries, buffer, scanState); console.debug('-- PROBABILITY CHECKS BEGIN --'); for (const entry of newEntries) { @@ -3697,7 +3751,7 @@ async function checkWorldInfo(chat, maxContext, isDryRun) { console.log('Alerting'); toastr.warning(`World info budget reached after ${allActivatedEntries.size} entries.`, 'World Info'); } - needsToScan = false; + scanState = scan_state.NONE; token_budget_overflowed = true; break; } @@ -3710,15 +3764,15 @@ async function checkWorldInfo(chat, maxContext, isDryRun) { if ((probabilityChecksAfter - probabilityChecksBefore) === activatedNow.size) { console.debug('WI probability checks failed for all activated entries, stopping'); - needsToScan = false; + scanState = scan_state.NONE; } if (newEntries.length === 0) { console.debug('No new entries activated, stopping'); - needsToScan = false; + scanState = scan_state.NONE; } - if (needsToScan) { + if (scanState) { const text = newEntries .filter(x => !failedProbabilityChecks.has(x)) .filter(x => !x.preventRecursion) @@ -3728,7 +3782,7 @@ async function checkWorldInfo(chat, maxContext, isDryRun) { } // world_info_min_activations - if (!needsToScan && !token_budget_overflowed) { + if (!scanState && !token_budget_overflowed) { if (world_info_min_activations > 0 && (allActivatedEntries.size < world_info_min_activations)) { let over_max = ( world_info_min_activations_depth_max > 0 && @@ -3736,7 +3790,7 @@ async function checkWorldInfo(chat, maxContext, isDryRun) { ) || (buffer.getDepth() > chat.length); if (!over_max) { - needsToScan = true; // loop + scanState = scan_state.MIN_ACTIVATIONS; // loop buffer.advanceScanPosition(); } } @@ -3824,8 +3878,9 @@ async function checkWorldInfo(chat, maxContext, isDryRun) { * @param {Record} groups The groups to filter * @param {WorldInfoBuffer} buffer The buffer to use for scoring * @param {(entry: WIScanEntry) => void} removeEntry The function to remove an entry + * @param {number} scanState The current scan state */ -function filterGroupsByScoring(groups, buffer, removeEntry) { +function filterGroupsByScoring(groups, buffer, removeEntry, scanState) { for (const [key, group] of Object.entries(groups)) { // Group scoring is disabled both globally and for the group entries if (!world_info_use_group_scoring && !group.some(x => x.useGroupScoring)) { @@ -3833,7 +3888,7 @@ function filterGroupsByScoring(groups, buffer, removeEntry) { continue; } - const scores = group.map(entry => buffer.getScore(entry)); + const scores = group.map(entry => buffer.getScore(entry, scanState)); const maxScore = Math.max(...scores); console.debug(`Group '${key}' max score: ${maxScore}`); //console.table(group.map((entry, i) => ({ uid: entry.uid, key: JSON.stringify(entry.key), score: scores[i] }))); @@ -3861,8 +3916,9 @@ function filterGroupsByScoring(groups, buffer, removeEntry) { * @param {object[]} newEntries Entries activated on current recursion level * @param {Set} allActivatedEntries Set of all activated entries * @param {WorldInfoBuffer} buffer The buffer to use for scanning + * @param {number} scanState The current scan state */ -function filterByInclusionGroups(newEntries, allActivatedEntries, buffer) { +function filterByInclusionGroups(newEntries, allActivatedEntries, buffer, scanState) { console.debug('-- INCLUSION GROUP CHECKS BEGIN --'); const grouped = newEntries.filter(x => x.group).reduce((acc, item) => { item.group.split(/,\s*/).filter(x => x).forEach(group => { @@ -3891,7 +3947,7 @@ function filterByInclusionGroups(newEntries, allActivatedEntries, buffer) { } } - filterGroupsByScoring(grouped, buffer, removeEntry); + filterGroupsByScoring(grouped, buffer, removeEntry, scanState); for (const [key, group] of Object.entries(grouped)) { console.debug(`Checking inclusion group '${key}' with ${group.length} entries`, group);