Merge pull request #2468 from SillyTavern/wi-scan-state
Fix min activations for non-recursable entries
This commit is contained in:
commit
81f6520354
|
@ -50,6 +50,28 @@ const world_info_logic = {
|
||||||
AND_ALL: 3,
|
AND_ALL: 3,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @enum {number} Possible states of the WI evaluation
|
||||||
|
*/
|
||||||
|
const scan_state = {
|
||||||
|
/**
|
||||||
|
* The scan will be stopped.
|
||||||
|
*/
|
||||||
|
NONE: 0,
|
||||||
|
/**
|
||||||
|
* Initial state.
|
||||||
|
*/
|
||||||
|
INITIAL: 1,
|
||||||
|
/**
|
||||||
|
* The scan is triggered by a recursion step.
|
||||||
|
*/
|
||||||
|
RECURSION: 2,
|
||||||
|
/**
|
||||||
|
* The scan is triggered by a min activations depth skew.
|
||||||
|
*/
|
||||||
|
MIN_ACTIVATIONS: 2,
|
||||||
|
};
|
||||||
|
|
||||||
const WI_ENTRY_EDIT_TEMPLATE = $('#entry_edit_template .world_entry');
|
const WI_ENTRY_EDIT_TEMPLATE = $('#entry_edit_template .world_entry');
|
||||||
|
|
||||||
let world_info = {};
|
let world_info = {};
|
||||||
|
@ -135,6 +157,11 @@ class WorldInfoBuffer {
|
||||||
*/
|
*/
|
||||||
#recurseBuffer = [];
|
#recurseBuffer = [];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @type {string[]} Array of strings added by prompt injections that are valid for the current scan
|
||||||
|
*/
|
||||||
|
#injectBuffer = [];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @type {number} The skew of the global scan depth. Used in "min activations"
|
* @type {number} The skew of the global scan depth. Used in "min activations"
|
||||||
*/
|
*/
|
||||||
|
@ -184,9 +211,10 @@ class WorldInfoBuffer {
|
||||||
/**
|
/**
|
||||||
* Gets all messages up to the given depth + recursion buffer.
|
* Gets all messages up to the given depth + recursion buffer.
|
||||||
* @param {WIScanEntry} entry The entry that triggered the scan
|
* @param {WIScanEntry} entry The entry that triggered the scan
|
||||||
|
* @param {number} scanState The state of the scan
|
||||||
* @returns {string} A slice of buffer until the given depth (inclusive)
|
* @returns {string} A slice of buffer until the given depth (inclusive)
|
||||||
*/
|
*/
|
||||||
get(entry) {
|
get(entry, scanState) {
|
||||||
let depth = entry.scanDepth ?? this.getDepth();
|
let depth = entry.scanDepth ?? this.getDepth();
|
||||||
if (depth <= this.#startDepth) {
|
if (depth <= this.#startDepth) {
|
||||||
return '';
|
return '';
|
||||||
|
@ -204,7 +232,12 @@ class WorldInfoBuffer {
|
||||||
|
|
||||||
let result = this.#depthBuffer.slice(this.#startDepth, depth).join('\n');
|
let result = this.#depthBuffer.slice(this.#startDepth, depth).join('\n');
|
||||||
|
|
||||||
if (this.#recurseBuffer.length > 0) {
|
if (this.#injectBuffer.length > 0) {
|
||||||
|
result += '\n' + this.#injectBuffer.join('\n');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Min activations should not include the recursion buffer
|
||||||
|
if (this.#recurseBuffer.length > 0 && scanState !== scan_state.MIN_ACTIVATIONS) {
|
||||||
result += '\n' + this.#recurseBuffer.join('\n');
|
result += '\n' + this.#recurseBuffer.join('\n');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -258,6 +291,14 @@ class WorldInfoBuffer {
|
||||||
this.#recurseBuffer.push(message);
|
this.#recurseBuffer.push(message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds an injection to the buffer.
|
||||||
|
* @param {string} message The injection to add
|
||||||
|
*/
|
||||||
|
addInject(message) {
|
||||||
|
this.#injectBuffer.push(message);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Increments skew and sets startDepth to previous depth.
|
* Increments skew and sets startDepth to previous depth.
|
||||||
*/
|
*/
|
||||||
|
@ -293,10 +334,11 @@ class WorldInfoBuffer {
|
||||||
/**
|
/**
|
||||||
* Gets the match score for the given entry.
|
* Gets the match score for the given entry.
|
||||||
* @param {WIScanEntry} entry Entry to check
|
* @param {WIScanEntry} entry Entry to check
|
||||||
|
* @param {number} scanState The state of the scan
|
||||||
* @returns {number} The number of key activations for the given entry
|
* @returns {number} The number of key activations for the given entry
|
||||||
*/
|
*/
|
||||||
getScore(entry) {
|
getScore(entry, scanState) {
|
||||||
const bufferState = this.get(entry);
|
const bufferState = this.get(entry, scanState);
|
||||||
let numberOfPrimaryKeys = 0;
|
let numberOfPrimaryKeys = 0;
|
||||||
let numberOfSecondaryKeys = 0;
|
let numberOfSecondaryKeys = 0;
|
||||||
let primaryScore = 0;
|
let primaryScore = 0;
|
||||||
|
@ -3503,12 +3545,12 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
|
||||||
if (context.extensionPrompts[key]?.scan) {
|
if (context.extensionPrompts[key]?.scan) {
|
||||||
const prompt = getExtensionPromptByName(key);
|
const prompt = getExtensionPromptByName(key);
|
||||||
if (prompt) {
|
if (prompt) {
|
||||||
buffer.addRecurse(prompt);
|
buffer.addInject(prompt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let needsToScan = true;
|
let scanState = scan_state.INITIAL;
|
||||||
let token_budget_overflowed = false;
|
let token_budget_overflowed = false;
|
||||||
let count = 0;
|
let count = 0;
|
||||||
let allActivatedEntries = new Set();
|
let allActivatedEntries = new Set();
|
||||||
|
@ -3532,8 +3574,9 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
|
||||||
return { worldInfoBefore: '', worldInfoAfter: '', WIDepthEntries: [], EMEntries: [], allActivatedEntries: new Set() };
|
return { worldInfoBefore: '', worldInfoAfter: '', WIDepthEntries: [], EMEntries: [], allActivatedEntries: new Set() };
|
||||||
}
|
}
|
||||||
|
|
||||||
while (needsToScan) {
|
while (scanState) {
|
||||||
// Track how many times the loop has run
|
// Track how many times the loop has run. May be useful for debugging.
|
||||||
|
// eslint-disable-next-line no-unused-vars
|
||||||
count++;
|
count++;
|
||||||
|
|
||||||
let activatedNow = new Set();
|
let activatedNow = new Set();
|
||||||
|
@ -3587,7 +3630,18 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (allActivatedEntries.has(entry) || entry.disable == true || (count > 1 && world_info_recursive && entry.excludeRecursion) || (count == 1 && entry.delayUntilRecursion)) {
|
if (allActivatedEntries.has(entry) || entry.disable == true) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only use checks for recursion flags if the scan step was activated by recursion
|
||||||
|
if (scanState !== scan_state.RECURSION && entry.delayUntilRecursion) {
|
||||||
|
console.debug(`WI entry ${entry.uid} suppressed by delay until recursion`, entry);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (scanState === scan_state.RECURSION && world_info_recursive && entry.excludeRecursion) {
|
||||||
|
console.debug(`WI entry ${entry.uid} suppressed by exclude recursion`, entry);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3602,7 +3656,7 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
|
||||||
|
|
||||||
primary: for (let key of entry.key) {
|
primary: for (let key of entry.key) {
|
||||||
const substituted = substituteParams(key);
|
const substituted = substituteParams(key);
|
||||||
const textToScan = buffer.get(entry);
|
const textToScan = buffer.get(entry, scanState);
|
||||||
|
|
||||||
if (substituted && buffer.matchKeys(textToScan, substituted.trim(), entry)) {
|
if (substituted && buffer.matchKeys(textToScan, substituted.trim(), entry)) {
|
||||||
console.debug(`WI UID ${entry.uid} found by primary match: ${substituted}.`);
|
console.debug(`WI UID ${entry.uid} found by primary match: ${substituted}.`);
|
||||||
|
@ -3665,14 +3719,14 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
needsToScan = world_info_recursive && activatedNow.size > 0;
|
scanState = world_info_recursive && activatedNow.size > 0 ? scan_state.RECURSION : scan_state.NONE;
|
||||||
const newEntries = [...activatedNow]
|
const newEntries = [...activatedNow]
|
||||||
.sort((a, b) => sortedEntries.indexOf(a) - sortedEntries.indexOf(b));
|
.sort((a, b) => sortedEntries.indexOf(a) - sortedEntries.indexOf(b));
|
||||||
let newContent = '';
|
let newContent = '';
|
||||||
const textToScanTokens = await getTokenCountAsync(allActivatedText);
|
const textToScanTokens = await getTokenCountAsync(allActivatedText);
|
||||||
const probabilityChecksBefore = failedProbabilityChecks.size;
|
const probabilityChecksBefore = failedProbabilityChecks.size;
|
||||||
|
|
||||||
filterByInclusionGroups(newEntries, allActivatedEntries, buffer);
|
filterByInclusionGroups(newEntries, allActivatedEntries, buffer, scanState);
|
||||||
|
|
||||||
console.debug('-- PROBABILITY CHECKS BEGIN --');
|
console.debug('-- PROBABILITY CHECKS BEGIN --');
|
||||||
for (const entry of newEntries) {
|
for (const entry of newEntries) {
|
||||||
|
@ -3697,7 +3751,7 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
|
||||||
console.log('Alerting');
|
console.log('Alerting');
|
||||||
toastr.warning(`World info budget reached after ${allActivatedEntries.size} entries.`, 'World Info');
|
toastr.warning(`World info budget reached after ${allActivatedEntries.size} entries.`, 'World Info');
|
||||||
}
|
}
|
||||||
needsToScan = false;
|
scanState = scan_state.NONE;
|
||||||
token_budget_overflowed = true;
|
token_budget_overflowed = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -3710,15 +3764,15 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
|
||||||
|
|
||||||
if ((probabilityChecksAfter - probabilityChecksBefore) === activatedNow.size) {
|
if ((probabilityChecksAfter - probabilityChecksBefore) === activatedNow.size) {
|
||||||
console.debug('WI probability checks failed for all activated entries, stopping');
|
console.debug('WI probability checks failed for all activated entries, stopping');
|
||||||
needsToScan = false;
|
scanState = scan_state.NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (newEntries.length === 0) {
|
if (newEntries.length === 0) {
|
||||||
console.debug('No new entries activated, stopping');
|
console.debug('No new entries activated, stopping');
|
||||||
needsToScan = false;
|
scanState = scan_state.NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (needsToScan) {
|
if (scanState) {
|
||||||
const text = newEntries
|
const text = newEntries
|
||||||
.filter(x => !failedProbabilityChecks.has(x))
|
.filter(x => !failedProbabilityChecks.has(x))
|
||||||
.filter(x => !x.preventRecursion)
|
.filter(x => !x.preventRecursion)
|
||||||
|
@ -3728,7 +3782,7 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// world_info_min_activations
|
// world_info_min_activations
|
||||||
if (!needsToScan && !token_budget_overflowed) {
|
if (!scanState && !token_budget_overflowed) {
|
||||||
if (world_info_min_activations > 0 && (allActivatedEntries.size < world_info_min_activations)) {
|
if (world_info_min_activations > 0 && (allActivatedEntries.size < world_info_min_activations)) {
|
||||||
let over_max = (
|
let over_max = (
|
||||||
world_info_min_activations_depth_max > 0 &&
|
world_info_min_activations_depth_max > 0 &&
|
||||||
|
@ -3736,7 +3790,7 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
|
||||||
) || (buffer.getDepth() > chat.length);
|
) || (buffer.getDepth() > chat.length);
|
||||||
|
|
||||||
if (!over_max) {
|
if (!over_max) {
|
||||||
needsToScan = true; // loop
|
scanState = scan_state.MIN_ACTIVATIONS; // loop
|
||||||
buffer.advanceScanPosition();
|
buffer.advanceScanPosition();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3824,8 +3878,9 @@ async function checkWorldInfo(chat, maxContext, isDryRun) {
|
||||||
* @param {Record<string, WIScanEntry[]>} groups The groups to filter
|
* @param {Record<string, WIScanEntry[]>} groups The groups to filter
|
||||||
* @param {WorldInfoBuffer} buffer The buffer to use for scoring
|
* @param {WorldInfoBuffer} buffer The buffer to use for scoring
|
||||||
* @param {(entry: WIScanEntry) => void} removeEntry The function to remove an entry
|
* @param {(entry: WIScanEntry) => void} removeEntry The function to remove an entry
|
||||||
|
* @param {number} scanState The current scan state
|
||||||
*/
|
*/
|
||||||
function filterGroupsByScoring(groups, buffer, removeEntry) {
|
function filterGroupsByScoring(groups, buffer, removeEntry, scanState) {
|
||||||
for (const [key, group] of Object.entries(groups)) {
|
for (const [key, group] of Object.entries(groups)) {
|
||||||
// Group scoring is disabled both globally and for the group entries
|
// Group scoring is disabled both globally and for the group entries
|
||||||
if (!world_info_use_group_scoring && !group.some(x => x.useGroupScoring)) {
|
if (!world_info_use_group_scoring && !group.some(x => x.useGroupScoring)) {
|
||||||
|
@ -3833,7 +3888,7 @@ function filterGroupsByScoring(groups, buffer, removeEntry) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const scores = group.map(entry => buffer.getScore(entry));
|
const scores = group.map(entry => buffer.getScore(entry, scanState));
|
||||||
const maxScore = Math.max(...scores);
|
const maxScore = Math.max(...scores);
|
||||||
console.debug(`Group '${key}' max score: ${maxScore}`);
|
console.debug(`Group '${key}' max score: ${maxScore}`);
|
||||||
//console.table(group.map((entry, i) => ({ uid: entry.uid, key: JSON.stringify(entry.key), score: scores[i] })));
|
//console.table(group.map((entry, i) => ({ uid: entry.uid, key: JSON.stringify(entry.key), score: scores[i] })));
|
||||||
|
@ -3861,8 +3916,9 @@ function filterGroupsByScoring(groups, buffer, removeEntry) {
|
||||||
* @param {object[]} newEntries Entries activated on current recursion level
|
* @param {object[]} newEntries Entries activated on current recursion level
|
||||||
* @param {Set<object>} allActivatedEntries Set of all activated entries
|
* @param {Set<object>} allActivatedEntries Set of all activated entries
|
||||||
* @param {WorldInfoBuffer} buffer The buffer to use for scanning
|
* @param {WorldInfoBuffer} buffer The buffer to use for scanning
|
||||||
|
* @param {number} scanState The current scan state
|
||||||
*/
|
*/
|
||||||
function filterByInclusionGroups(newEntries, allActivatedEntries, buffer) {
|
function filterByInclusionGroups(newEntries, allActivatedEntries, buffer, scanState) {
|
||||||
console.debug('-- INCLUSION GROUP CHECKS BEGIN --');
|
console.debug('-- INCLUSION GROUP CHECKS BEGIN --');
|
||||||
const grouped = newEntries.filter(x => x.group).reduce((acc, item) => {
|
const grouped = newEntries.filter(x => x.group).reduce((acc, item) => {
|
||||||
item.group.split(/,\s*/).filter(x => x).forEach(group => {
|
item.group.split(/,\s*/).filter(x => x).forEach(group => {
|
||||||
|
@ -3891,7 +3947,7 @@ function filterByInclusionGroups(newEntries, allActivatedEntries, buffer) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
filterGroupsByScoring(grouped, buffer, removeEntry);
|
filterGroupsByScoring(grouped, buffer, removeEntry, scanState);
|
||||||
|
|
||||||
for (const [key, group] of Object.entries(grouped)) {
|
for (const [key, group] of Object.entries(grouped)) {
|
||||||
console.debug(`Checking inclusion group '${key}' with ${group.length} entries`, group);
|
console.debug(`Checking inclusion group '${key}' with ${group.length} entries`, group);
|
||||||
|
|
Loading…
Reference in New Issue