Per-entry group scoring

This commit is contained in:
Cohee 2024-05-05 00:42:33 +03:00
parent 2bf9869e5f
commit 39a54d158d
3 changed files with 46 additions and 6 deletions

View File

@ -5121,6 +5121,14 @@
<option value="false" data-i18n="No">No</option> <option value="false" data-i18n="No">No</option>
</select> </select>
</div> </div>
<div class="world_entry_form_control flex1">
<small class="textAlignCenter" data-i18n="Use Group Scoring">Use Group Scoring</small>
<select name="useGroupScoring" class="text_pole widthNatural margin0">
<option value="null" data-i18n="Use global setting">Use global setting</option>
<option value="true" data-i18n="Yes">Yes</option>
<option value="false" data-i18n="No">No</option>
</select>
</div>
<div class="world_entry_form_control flex1" title="Can be used to automatically activate Quick Replies" data-i18n="[title]Can be used to automatically activate Quick Replies"> <div class="world_entry_form_control flex1" title="Can be used to automatically activate Quick Replies" data-i18n="[title]Can be used to automatically activate Quick Replies">
<small class="textAlignCenter" data-i18n="Automation ID">Automation ID</small> <small class="textAlignCenter" data-i18n="Automation ID">Automation ID</small>
<input class="text_pole margin0" name="automationId" type="text" placeholder="( None )" data-i18n="[placeholder]( None )"> <input class="text_pole margin0" name="automationId" type="text" placeholder="( None )" data-i18n="[placeholder]( None )">

View File

@ -86,6 +86,7 @@ class WorldInfoBuffer {
* @property {number} [scanDepth] The depth of the scan * @property {number} [scanDepth] The depth of the scan
* @property {boolean} [caseSensitive] If the scan is case sensitive * @property {boolean} [caseSensitive] If the scan is case sensitive
* @property {boolean} [matchWholeWords] If the scan should match whole words * @property {boolean} [matchWholeWords] If the scan should match whole words
* @property {boolean} [useGroupScoring] If the scan should use group scoring
* @property {number} [uid] The UID of the entry that triggered the scan * @property {number} [uid] The UID of the entry that triggered the scan
* @property {string[]} [key] The primary keys to scan for * @property {string[]} [key] The primary keys to scan for
* @property {string[]} [keysecondary] The secondary keys to scan for * @property {string[]} [keysecondary] The secondary keys to scan for
@ -1086,6 +1087,7 @@ const originalDataKeyMap = {
'keysecondary': 'secondary_keys', 'keysecondary': 'secondary_keys',
'selective': 'selective', 'selective': 'selective',
'matchWholeWords': 'extensions.match_whole_words', 'matchWholeWords': 'extensions.match_whole_words',
'useGroupScoring': 'extensions.use_group_scoring',
'caseSensitive': 'extensions.case_sensitive', 'caseSensitive': 'extensions.case_sensitive',
'scanDepth': 'extensions.scan_depth', 'scanDepth': 'extensions.scan_depth',
'automationId': 'extensions.automation_id', 'automationId': 'extensions.automation_id',
@ -1779,6 +1781,19 @@ function getWorldEntry(name, data, entry) {
}); });
matchWholeWordsSelect.val((entry.matchWholeWords === null || entry.matchWholeWords === undefined) ? 'null' : entry.matchWholeWords ? 'true' : 'false').trigger('input'); matchWholeWordsSelect.val((entry.matchWholeWords === null || entry.matchWholeWords === undefined) ? 'null' : entry.matchWholeWords ? 'true' : 'false').trigger('input');
// use group scoring select
const useGroupScoringSelect = template.find('select[name="useGroupScoring"]');
useGroupScoringSelect.data('uid', entry.uid);
useGroupScoringSelect.on('input', function () {
const uid = $(this).data('uid');
const value = $(this).val();
data.entries[uid].useGroupScoring = value === 'null' ? null : value === 'true';
setOriginalDataValue(data, uid, 'extensions.use_group_scoring', data.entries[uid].useGroupScoring);
saveWorldInfo(name, data);
});
useGroupScoringSelect.val((entry.useGroupScoring === null || entry.useGroupScoring === undefined) ? 'null' : entry.useGroupScoring ? 'true' : 'false').trigger('input');
// automation id // automation id
const automationIdInput = template.find('input[name="automationId"]'); const automationIdInput = template.find('input[name="automationId"]');
automationIdInput.data('uid', entry.uid); automationIdInput.data('uid', entry.uid);
@ -1959,6 +1974,7 @@ const newEntryTemplate = {
scanDepth: null, scanDepth: null,
caseSensitive: null, caseSensitive: null,
matchWholeWords: null, matchWholeWords: null,
useGroupScoring: null,
automationId: '', automationId: '',
role: 0, role: 0,
}; };
@ -2526,19 +2542,33 @@ async function checkWorldInfo(chat, maxContext) {
* Only leaves entries with the highest key matching score in each group. * Only leaves entries with the highest key matching score in each group.
* @param {Record<string, WIScanEntry[]>} groups The groups to filter * @param {Record<string, WIScanEntry[]>} groups The groups to filter
* @param {WorldInfoBuffer} buffer The buffer to use for scoring * @param {WorldInfoBuffer} buffer The buffer to use for scoring
* @param {(entry: WIScanEntry) => void} removeEntry The function to remove an entry
*/ */
function filterGroupsByScoring(groups, buffer) { function filterGroupsByScoring(groups, buffer, removeEntry) {
for (const [key, group] of Object.entries(groups)) { for (const [key, group] of Object.entries(groups)) {
// Group scoring is disabled both globally and for the group entries
if (!world_info_use_group_scoring && !group.some(x => x.useGroupScoring)) {
console.debug(`Skipping group scoring for group '${key}'`);
continue;
}
const scores = group.map(entry => buffer.getScore(entry)); const scores = group.map(entry => buffer.getScore(entry));
const maxScore = Math.max(...scores); const maxScore = Math.max(...scores);
console.debug(`Group '${key}' max score: ${maxScore}`); console.debug(`Group '${key}' max score: ${maxScore}`);
//console.table(group.map((entry, i) => ({ uid: entry.uid, key: JSON.stringify(entry.key), score: scores[i] }))); //console.table(group.map((entry, i) => ({ uid: entry.uid, key: JSON.stringify(entry.key), score: scores[i] })));
for (let i = 0; i < group.length; i++) { for (let i = 0; i < group.length; i++) {
const isScored = group[i].useGroupScoring ?? world_info_use_group_scoring;
if (!isScored) {
continue;
}
if (scores[i] < maxScore) { if (scores[i] < maxScore) {
console.debug(`Removing score loser from inclusion group '${key}' entry '${group[i].uid}'`, group[i]); console.debug(`Removing score loser from inclusion group '${key}' entry '${group[i].uid}'`, group[i]);
group.splice(i, 1); group.splice(i, 1);
scores.splice(i, 1); scores.splice(i, 1);
removeEntry(group[i]);
i--; i--;
} }
} }
@ -2566,11 +2596,6 @@ function filterByInclusionGroups(newEntries, allActivatedEntries, buffer) {
return; return;
} }
if (world_info_use_group_scoring) {
console.debug('Using group scoring');
filterGroupsByScoring(grouped, buffer);
}
const removeEntry = (entry) => newEntries.splice(newEntries.indexOf(entry), 1); const removeEntry = (entry) => newEntries.splice(newEntries.indexOf(entry), 1);
function removeAllBut(group, chosen, logging = true) { function removeAllBut(group, chosen, logging = true) {
for (const entry of group) { for (const entry of group) {
@ -2583,6 +2608,8 @@ function filterByInclusionGroups(newEntries, allActivatedEntries, buffer) {
} }
} }
filterGroupsByScoring(grouped, buffer, removeEntry);
for (const [key, group] of Object.entries(grouped)) { for (const [key, group] of Object.entries(grouped)) {
console.debug(`Checking inclusion group '${key}' with ${group.length} entries`, group); console.debug(`Checking inclusion group '${key}' with ${group.length} entries`, group);
@ -2660,6 +2687,7 @@ function convertAgnaiMemoryBook(inputObj) {
scanDepth: null, scanDepth: null,
caseSensitive: null, caseSensitive: null,
matchWholeWords: null, matchWholeWords: null,
useGroupScoring: null,
automationId: '', automationId: '',
role: extension_prompt_roles.SYSTEM, role: extension_prompt_roles.SYSTEM,
}; };
@ -2696,6 +2724,7 @@ function convertRisuLorebook(inputObj) {
scanDepth: null, scanDepth: null,
caseSensitive: null, caseSensitive: null,
matchWholeWords: null, matchWholeWords: null,
useGroupScoring: null,
automationId: '', automationId: '',
role: extension_prompt_roles.SYSTEM, role: extension_prompt_roles.SYSTEM,
}; };
@ -2737,6 +2766,7 @@ function convertNovelLorebook(inputObj) {
scanDepth: null, scanDepth: null,
caseSensitive: null, caseSensitive: null,
matchWholeWords: null, matchWholeWords: null,
useGroupScoring: null,
automationId: '', automationId: '',
role: extension_prompt_roles.SYSTEM, role: extension_prompt_roles.SYSTEM,
}; };
@ -2779,6 +2809,7 @@ function convertCharacterBook(characterBook) {
scanDepth: entry.extensions?.scan_depth ?? null, scanDepth: entry.extensions?.scan_depth ?? null,
caseSensitive: entry.extensions?.case_sensitive ?? null, caseSensitive: entry.extensions?.case_sensitive ?? null,
matchWholeWords: entry.extensions?.match_whole_words ?? null, matchWholeWords: entry.extensions?.match_whole_words ?? null,
useGroupScoring: entry.extensions?.use_group_scoring ?? null,
automationId: entry.extensions?.automation_id ?? '', automationId: entry.extensions?.automation_id ?? '',
role: entry.extensions?.role ?? extension_prompt_roles.SYSTEM, role: entry.extensions?.role ?? extension_prompt_roles.SYSTEM,
vectorized: entry.extensions?.vectorized ?? false, vectorized: entry.extensions?.vectorized ?? false,

View File

@ -437,6 +437,7 @@ function convertWorldInfoToCharacterBook(name, entries) {
prevent_recursion: entry.preventRecursion ?? false, prevent_recursion: entry.preventRecursion ?? false,
scan_depth: entry.scanDepth ?? null, scan_depth: entry.scanDepth ?? null,
match_whole_words: entry.matchWholeWords ?? null, match_whole_words: entry.matchWholeWords ?? null,
use_group_scoring: entry.useGroupScoring ?? false,
case_sensitive: entry.caseSensitive ?? null, case_sensitive: entry.caseSensitive ?? null,
automation_id: entry.automationId ?? '', automation_id: entry.automationId ?? '',
role: entry.role ?? 0, role: entry.role ?? 0,