mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Change insertion strategy to an extension block
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
import { eventSource, event_types, getCurrentChatId, getRequestHeaders, saveSettingsDebounced } from "../../../script.js";
|
||||
import { eventSource, event_types, extension_prompt_types, getCurrentChatId, getRequestHeaders, saveSettingsDebounced, setExtensionPrompt } from "../../../script.js";
|
||||
import { ModuleWorkerWrapper, extension_settings, getContext, renderExtensionTemplate } from "../../extensions.js";
|
||||
import { collapseNewlines } from "../../power-user.js";
|
||||
import { debounce, getStringHash as calculateHash } from "../../utils.js";
|
||||
|
||||
const MODULE_NAME = 'vectors';
|
||||
const MIN_TO_LEAVE = 5;
|
||||
const QUERY_AMOUNT = 2;
|
||||
const LEAVE_RATIO = 0.5;
|
||||
const AMOUNT_TO_LEAVE = 5;
|
||||
const INSERT_AMOUNT = 3;
|
||||
const QUERY_TEXT_AMOUNT = 3;
|
||||
|
||||
export const EXTENSION_PROMPT_TAG = '3_vectors';
|
||||
|
||||
const settings = {
|
||||
enabled: false,
|
||||
@@ -72,7 +74,7 @@ function getStringHash(str) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Rearranges the chat based on the relevance of recent messages
|
||||
* Removes the most relevant messages from the chat and displays them in the extension prompt
|
||||
* @param {object[]} chat Array of chat messages
|
||||
*/
|
||||
async function rearrangeChat(chat) {
|
||||
@@ -88,8 +90,8 @@ async function rearrangeChat(chat) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (chat.length < MIN_TO_LEAVE) {
|
||||
console.debug(`Vectors: Not enough messages to rearrange (less than ${MIN_TO_LEAVE})`);
|
||||
if (chat.length < AMOUNT_TO_LEAVE) {
|
||||
console.debug(`Vectors: Not enough messages to rearrange (less than ${AMOUNT_TO_LEAVE})`);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -100,48 +102,34 @@ async function rearrangeChat(chat) {
|
||||
return;
|
||||
}
|
||||
|
||||
const queryHashes = await queryCollection(chatId, queryText);
|
||||
|
||||
// Sorting logic
|
||||
// 1. 50% of messages at the end stay in the same place (minimum 5)
|
||||
// 2. Messages that are in the query are rearranged to match the query order
|
||||
// 3. Messages that are not in the query and are not in the top 50% stay in the same place
|
||||
// Get the most relevant messages, excluding the last few
|
||||
const queryHashes = await queryCollection(chatId, queryText, INSERT_AMOUNT);
|
||||
const queriedMessages = [];
|
||||
const remainingMessages = [];
|
||||
const retainMessages = chat.slice(-AMOUNT_TO_LEAVE);
|
||||
|
||||
// Leave the last N messages intact
|
||||
const retainMessagesCount = Math.max(Math.floor(chat.length * LEAVE_RATIO), MIN_TO_LEAVE);
|
||||
const lastNMessages = chat.slice(-retainMessagesCount);
|
||||
|
||||
// Splitting messages into queried and remaining messages
|
||||
for (const message of chat) {
|
||||
if (lastNMessages.includes(message)) {
|
||||
if (retainMessages.includes(message)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (message.mes && queryHashes.includes(getStringHash(message.mes))) {
|
||||
queriedMessages.push(message);
|
||||
} else {
|
||||
remainingMessages.push(message);
|
||||
}
|
||||
}
|
||||
|
||||
// Rearrange queried messages to match query order
|
||||
// Order is reversed because more relevant messages are at the lower indices
|
||||
queriedMessages.sort((a, b) => {
|
||||
return queryHashes.indexOf(getStringHash(b.mes)) - queryHashes.indexOf(getStringHash(a.mes));
|
||||
});
|
||||
queriedMessages.sort((a, b) => queryHashes.indexOf(getStringHash(b.mes)) - queryHashes.indexOf(getStringHash(a.mes)));
|
||||
|
||||
// Construct the final rearranged chat
|
||||
const rearrangedChat = [...remainingMessages, ...queriedMessages, ...lastNMessages];
|
||||
|
||||
if (rearrangedChat.length !== chat.length) {
|
||||
console.error('Vectors: Rearranged chat length does not match original chat length! This should not happen.');
|
||||
return;
|
||||
// Remove queried messages from the original chat array
|
||||
for (const message of chat) {
|
||||
if (queriedMessages.includes(message)) {
|
||||
chat.splice(chat.indexOf(message), 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Update the original chat array in-place
|
||||
chat.splice(0, chat.length, ...rearrangedChat);
|
||||
// Format queried messages into a single string
|
||||
const queriedText = 'Past events: ' + queriedMessages.map(x => collapseNewlines(`${x.name}: ${x.mes}`).trim()).join('\n\n');
|
||||
setExtensionPrompt(EXTENSION_PROMPT_TAG, queriedText, extension_prompt_types.IN_PROMPT, 0);
|
||||
} catch (error) {
|
||||
console.error('Vectors: Failed to rearrange chat', error);
|
||||
}
|
||||
@@ -151,6 +139,11 @@ window['vectors_rearrangeChat'] = rearrangeChat;
|
||||
|
||||
const onChatEvent = debounce(async () => await moduleWorker.update(), 500);
|
||||
|
||||
/**
|
||||
* Gets the text to query from the chat
|
||||
* @param {object[]} chat Chat messages
|
||||
* @returns {string} Text to query
|
||||
*/
|
||||
function getQueryText(chat) {
|
||||
let queryText = '';
|
||||
let i = 0;
|
||||
@@ -161,7 +154,7 @@ function getQueryText(chat) {
|
||||
i++;
|
||||
}
|
||||
|
||||
if (i === QUERY_AMOUNT) {
|
||||
if (i === QUERY_TEXT_AMOUNT) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -228,13 +221,14 @@ async function deleteVectorItems(collectionId, hashes) {
|
||||
/**
|
||||
* @param {string} collectionId - The collection to query
|
||||
* @param {string} searchText - The text to query
|
||||
* @param {number} topK - The number of results to return
|
||||
* @returns {Promise<number[]>} - Hashes of the results
|
||||
*/
|
||||
async function queryCollection(collectionId, searchText) {
|
||||
async function queryCollection(collectionId, searchText, topK) {
|
||||
const response = await fetch('/api/vector/query', {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
body: JSON.stringify({ collectionId, searchText }),
|
||||
body: JSON.stringify({ collectionId, searchText, topK }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
|
Reference in New Issue
Block a user