Add WI group scoring mode

This commit is contained in:
Cohee 2024-05-04 23:51:28 +03:00
parent b13434c505
commit 2bf9869e5f
2 changed files with 114 additions and 4 deletions

View File

@ -3398,6 +3398,12 @@
Match whole words
</small>
</label>
<label title="Only the entries with the most number of key matches will be selected for Inclusion Group filtering" data-i18n="[title]Only the entries with the most number of key matches will be selected for Inclusion Group filtering" class="checkbox_label flex1">
<input id="world_info_use_group_scoring" type="checkbox" />
<small data-i18n="Use Group Scoring" class="whitespacenowrap flex1">
Use Group Scoring
</small>
</label>
<label title="Alert if your world info is greater than the allocated budget." data-i18n="[title]Alert if your world info is greater than the allocated budget." class="checkbox_label flex1">
<input id="world_info_overflow_alert" type="checkbox" />
<small data-i18n="Alert On Overflow" class="whitespacenowrap flex1">

View File

@ -57,6 +57,7 @@ let world_info_recursive = false;
let world_info_overflow_alert = false;
let world_info_case_sensitive = false;
let world_info_match_whole_words = false;
let world_info_use_group_scoring = false;
let world_info_character_strategy = world_info_insertion_strategy.character_first;
let world_info_budget_cap = 0;
const saveWorldDebounced = debounce(async (name, data) => await _save(name, data), debounce_timeout.relaxed);
@ -80,7 +81,16 @@ const MAX_SCAN_DEPTH = 1000;
*/
class WorldInfoBuffer {
// Typedef area
/** @typedef {{scanDepth?: number, caseSensitive?: boolean, matchWholeWords?: boolean}} WIScanEntry The entry that triggered the scan */
/**
* @typedef {object} WIScanEntry The entry that triggered the scan
* @property {number} [scanDepth] The depth of the scan
* @property {boolean} [caseSensitive] If the scan is case sensitive
* @property {boolean} [matchWholeWords] If the scan should match whole words
* @property {number} [uid] The UID of the entry that triggered the scan
* @property {string[]} [key] The primary keys to scan for
* @property {string[]} [keysecondary] The secondary keys to scan for
* @property {number} [selectiveLogic] The logic to use for selective activation
*/
// End typedef area
/**
@ -244,6 +254,58 @@ class WorldInfoBuffer {
cleanExternalActivations() {
WorldInfoBuffer.externalActivations.splice(0, WorldInfoBuffer.externalActivations.length);
}
/**
* Gets the match score for the given entry.
* @param {WIScanEntry} entry Entry to check
* @returns {number} The number of key activations for the given entry
*/
getScore(entry) {
const bufferState = this.get(entry);
let numberOfPrimaryKeys = 0;
let numberOfSecondaryKeys = 0;
let primaryScore = 0;
let secondaryScore = 0;
// Increment score for every key found in the buffer
if (Array.isArray(entry.key)) {
numberOfPrimaryKeys = entry.key.length;
for (const key of entry.key) {
if (this.matchKeys(bufferState, key, entry)) {
primaryScore++;
}
}
}
// Increment score for every secondary key found in the buffer
if (Array.isArray(entry.keysecondary)) {
numberOfSecondaryKeys = entry.keysecondary.length;
for (const key of entry.keysecondary) {
if (this.matchKeys(bufferState, key, entry)) {
secondaryScore++;
}
}
}
// No keys == no score
if (!numberOfPrimaryKeys) {
return 0;
}
// Only positive logic influences the score
if (numberOfSecondaryKeys > 0) {
switch (entry.selectiveLogic) {
// AND_ANY: Add both scores
case world_info_logic.AND_ANY:
return primaryScore + secondaryScore;
// AND_ALL: Add both scores if all secondary keys are found, otherwise only primary score
case world_info_logic.AND_ALL:
return secondaryScore === numberOfSecondaryKeys ? primaryScore + secondaryScore : primaryScore;
}
}
return primaryScore;
}
}
export function getWorldInfoSettings() {
@ -259,6 +321,7 @@ export function getWorldInfoSettings() {
world_info_match_whole_words,
world_info_character_strategy,
world_info_budget_cap,
world_info_use_group_scoring,
};
}
@ -322,12 +385,18 @@ function setWorldInfoSettings(settings, data) {
world_info_character_strategy = Number(settings.world_info_character_strategy);
if (settings.world_info_budget_cap !== undefined)
world_info_budget_cap = Number(settings.world_info_budget_cap);
if (settings.world_info_use_group_scoring !== undefined)
world_info_use_group_scoring = Boolean(settings.world_info_use_group_scoring);
// Migrate old settings
if (world_info_budget > 100) {
world_info_budget = 25;
}
if (world_info_use_group_scoring === undefined) {
world_info_use_group_scoring = false;
}
// Reset selected world from old string and delete old keys
// TODO: Remove next release
const existingWorldInfo = settings.world_info;
@ -357,6 +426,7 @@ function setWorldInfoSettings(settings, data) {
$('#world_info_overflow_alert').prop('checked', world_info_overflow_alert);
$('#world_info_case_sensitive').prop('checked', world_info_case_sensitive);
$('#world_info_match_whole_words').prop('checked', world_info_match_whole_words);
$('#world_info_use_group_scoring').prop('checked', world_info_use_group_scoring);
$(`#world_info_character_strategy option[value='${world_info_character_strategy}']`).prop('selected', true);
$('#world_info_character_strategy').val(world_info_character_strategy);
@ -786,7 +856,7 @@ function displayWorldEntries(name, data, navigation = navigation_option.none) {
// Apply the filter and do the chosen sorting
entriesArray = worldInfoFilter.applyFilters(entriesArray);
entriesArray = sortEntries(entriesArray)
entriesArray = sortEntries(entriesArray);
// Run the callback for printing this
typeof callback === 'function' && callback(entriesArray);
@ -2332,7 +2402,7 @@ async function checkWorldInfo(chat, maxContext) {
const textToScanTokens = await getTokenCountAsync(allActivatedText);
const probabilityChecksBefore = failedProbabilityChecks.size;
filterByInclusionGroups(newEntries, allActivatedEntries);
filterByInclusionGroups(newEntries, allActivatedEntries, buffer);
console.debug('-- PROBABILITY CHECKS BEGIN --');
for (const entry of newEntries) {
@ -2452,12 +2522,36 @@ async function checkWorldInfo(chat, maxContext) {
return { worldInfoBefore, worldInfoAfter, WIDepthEntries, allActivatedEntries };
}
/**
* Only leaves entries with the highest key matching score in each group.
* @param {Record<string, WIScanEntry[]>} groups The groups to filter
* @param {WorldInfoBuffer} buffer The buffer to use for scoring
*/
function filterGroupsByScoring(groups, buffer) {
for (const [key, group] of Object.entries(groups)) {
const scores = group.map(entry => buffer.getScore(entry));
const maxScore = Math.max(...scores);
console.debug(`Group '${key}' max score: ${maxScore}`);
//console.table(group.map((entry, i) => ({ uid: entry.uid, key: JSON.stringify(entry.key), score: scores[i] })));
for (let i = 0; i < group.length; i++) {
if (scores[i] < maxScore) {
console.debug(`Removing score loser from inclusion group '${key}' entry '${group[i].uid}'`, group[i]);
group.splice(i, 1);
scores.splice(i, 1);
i--;
}
}
}
}
/**
* Filters entries by inclusion groups.
* @param {object[]} newEntries Entries activated on current recursion level
* @param {Set<object>} allActivatedEntries Set of all activated entries
* @param {WorldInfoBuffer} buffer The buffer to use for scanning
*/
function filterByInclusionGroups(newEntries, allActivatedEntries) {
function filterByInclusionGroups(newEntries, allActivatedEntries, buffer) {
console.debug('-- INCLUSION GROUP CHECKS BEGIN --');
const grouped = newEntries.filter(x => x.group).reduce((acc, item) => {
if (!acc[item.group]) {
@ -2472,6 +2566,11 @@ function filterByInclusionGroups(newEntries, allActivatedEntries) {
return;
}
if (world_info_use_group_scoring) {
console.debug('Using group scoring');
filterGroupsByScoring(grouped, buffer);
}
const removeEntry = (entry) => newEntries.splice(newEntries.indexOf(entry), 1);
function removeAllBut(group, chosen, logging = true) {
for (const entry of group) {
@ -3058,6 +3157,11 @@ jQuery(() => {
saveSettingsDebounced();
});
$('#world_info_use_group_scoring').on('change', function () {
world_info_use_group_scoring = !!$(this).prop('checked');
saveSettingsDebounced();
});
$('#world_info_budget_cap').on('input', function () {
world_info_budget_cap = Number($(this).val());
$('#world_info_budget_cap_counter').val(world_info_budget_cap);