mirror of https://github.com/SillyTavern/SillyTavern.git
synced 2025-02-10 00:50:43 +01:00

Change insertion strategy to an extension block

This commit is contained in:
parent 9d45c0a018
commit 96df705409
@@ -382,10 +382,7 @@ const system_message_types = {
 };
 
 const extension_prompt_types = {
-    /**
-     * @deprecated Outdated term. In reality it's "after main prompt or story string"
-     */
-    AFTER_SCENARIO: 0,
+    IN_PROMPT: 0,
     IN_CHAT: 1
 };
 
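For orientation, the renamed constant keeps the same numeric value, so stored position settings remain valid. A minimal sketch (not from the commit) of how call sites read after the rename; pickPosition() is a hypothetical helper for illustration:

// Values copied from the hunk above; only the constant names changed, not the numbers.
const extension_prompt_types = {
    IN_PROMPT: 0, // formerly AFTER_SCENARIO: "after main prompt or story string"
    IN_CHAT: 1,   // injected into the chat at a configurable depth
};

function pickPosition(useInChat) {
    return useInChat ? extension_prompt_types.IN_CHAT : extension_prompt_types.IN_PROMPT;
}

console.log(pickPosition(false)); // 0, the same value AFTER_SCENARIO had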
@@ -2533,7 +2530,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
     addPersonaDescriptionExtensionPrompt();
     // Call combined AN into Generate
     let allAnchors = getAllExtensionPrompts();
-    const afterScenarioAnchor = getExtensionPrompt(extension_prompt_types.AFTER_SCENARIO);
+    const afterScenarioAnchor = getExtensionPrompt(extension_prompt_types.IN_PROMPT);
     let zeroDepthAnchor = getExtensionPrompt(extension_prompt_types.IN_CHAT, 0, ' ');
 
     const storyStringParams = {
@@ -5591,7 +5588,7 @@ function select_rm_characters() {
  * @param {number} position Insertion position. 0 is after story string, 1 is in-chat with custom depth.
  * @param {number} depth Insertion depth. 0 represets the last message in context. Expected values up to 100.
  */
-function setExtensionPrompt(key, value, position, depth) {
+export function setExtensionPrompt(key, value, position, depth) {
     extension_prompts[key] = { value: String(value), position: Number(position), depth: Number(depth) };
 }
 
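Exporting setExtensionPrompt lets extension modules import it straight from script.js instead of going through getContext(). A self-contained sketch of the registry it writes to, assuming extension_prompts is a plain object keyed by extension tag, as the line above suggests:

// Minimal stand-in for the registry in script.js; the real extension_prompts object is module-scoped there.
const extension_prompts = {};

export function setExtensionPrompt(key, value, position, depth) {
    extension_prompts[key] = { value: String(value), position: Number(position), depth: Number(depth) };
}

// Hypothetical caller, mirroring how the vectors extension uses it later in this commit:
setExtensionPrompt('3_vectors', 'Past events: ...', 0 /* IN_PROMPT */, 0);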
@@ -739,7 +739,7 @@ window.chromadb_interceptGeneration = async (chat, maxContext) => {
         // No memories? No prompt.
         const promptBlob = (tokenApprox == 0) ? "" : wrapperMsg.replace('{{memories}}', allMemoryBlob);
         console.debug("CHROMADB: prompt blob: %o", promptBlob);
-        context.setExtensionPrompt(MODULE_NAME, promptBlob, extension_prompt_types.AFTER_SCENARIO);
+        context.setExtensionPrompt(MODULE_NAME, promptBlob, extension_prompt_types.IN_PROMPT);
     }
     if (selectedStrategy === 'custom') {
         const context = getContext();
@@ -63,7 +63,7 @@ const defaultSettings = {
    source: summary_sources.extras,
    prompt: defaultPrompt,
    template: defaultTemplate,
-   position: extension_prompt_types.AFTER_SCENARIO,
+   position: extension_prompt_types.IN_PROMPT,
    depth: 2,
    promptWords: 200,
    promptMinWords: 25,
@@ -1,12 +1,14 @@
-import { eventSource, event_types, getCurrentChatId, getRequestHeaders, saveSettingsDebounced } from "../../../script.js";
+import { eventSource, event_types, extension_prompt_types, getCurrentChatId, getRequestHeaders, saveSettingsDebounced, setExtensionPrompt } from "../../../script.js";
 import { ModuleWorkerWrapper, extension_settings, getContext, renderExtensionTemplate } from "../../extensions.js";
 import { collapseNewlines } from "../../power-user.js";
 import { debounce, getStringHash as calculateHash } from "../../utils.js";
 
 const MODULE_NAME = 'vectors';
-const MIN_TO_LEAVE = 5;
-const QUERY_AMOUNT = 2;
-const LEAVE_RATIO = 0.5;
+const AMOUNT_TO_LEAVE = 5;
+const INSERT_AMOUNT = 3;
+const QUERY_TEXT_AMOUNT = 3;
 
+export const EXTENSION_PROMPT_TAG = '3_vectors';
+
 const settings = {
     enabled: false,
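The renamed constants separate three concerns that the old LEAVE_RATIO logic mixed together. A rough sketch of how they relate, with the constant values taken from the diff and the chat data invented:

const AMOUNT_TO_LEAVE = 5;   // most recent messages that are never moved into the memory block
const INSERT_AMOUNT = 3;     // how many retrieved messages may be injected as the extension prompt
const QUERY_TEXT_AMOUNT = 3; // how many recent messages are concatenated into the query text

// Hypothetical 10-message chat: the last AMOUNT_TO_LEAVE stay put,
// older ones are candidates for retrieval into the '3_vectors' block.
const chat = Array.from({ length: 10 }, (_, i) => ({ name: 'User', mes: `message ${i}` }));
const retained = chat.slice(-AMOUNT_TO_LEAVE);
console.log(retained.length, chat.length - retained.length); // 5 stay in place, 5 are retrieval candidates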
@@ -72,7 +74,7 @@ function getStringHash(str) {
 }
 
 /**
- * Rearranges the chat based on the relevance of recent messages
+ * Removes the most relevant messages from the chat and displays them in the extension prompt
  * @param {object[]} chat Array of chat messages
  */
 async function rearrangeChat(chat) {
@@ -88,8 +90,8 @@ async function rearrangeChat(chat) {
             return;
         }
 
-        if (chat.length < MIN_TO_LEAVE) {
-            console.debug(`Vectors: Not enough messages to rearrange (less than ${MIN_TO_LEAVE})`);
+        if (chat.length < AMOUNT_TO_LEAVE) {
+            console.debug(`Vectors: Not enough messages to rearrange (less than ${AMOUNT_TO_LEAVE})`);
             return;
         }
 
@@ -100,48 +102,34 @@ async function rearrangeChat(chat) {
             return;
         }
 
-        const queryHashes = await queryCollection(chatId, queryText);
-        // Sorting logic
-        // 1. 50% of messages at the end stay in the same place (minimum 5)
-        // 2. Messages that are in the query are rearranged to match the query order
-        // 3. Messages that are not in the query and are not in the top 50% stay in the same place
+        // Get the most relevant messages, excluding the last few
+        const queryHashes = await queryCollection(chatId, queryText, INSERT_AMOUNT);
         const queriedMessages = [];
-        const remainingMessages = [];
+        const retainMessages = chat.slice(-AMOUNT_TO_LEAVE);
 
-        // Leave the last N messages intact
-        const retainMessagesCount = Math.max(Math.floor(chat.length * LEAVE_RATIO), MIN_TO_LEAVE);
-        const lastNMessages = chat.slice(-retainMessagesCount);
-
-        // Splitting messages into queried and remaining messages
         for (const message of chat) {
-            if (lastNMessages.includes(message)) {
+            if (retainMessages.includes(message)) {
                 continue;
             }
 
             if (message.mes && queryHashes.includes(getStringHash(message.mes))) {
                 queriedMessages.push(message);
-            } else {
-                remainingMessages.push(message);
             }
         }
 
         // Rearrange queried messages to match query order
         // Order is reversed because more relevant are at the lower indices
-        queriedMessages.sort((a, b) => {
-            return queryHashes.indexOf(getStringHash(b.mes)) - queryHashes.indexOf(getStringHash(a.mes));
-        });
+        queriedMessages.sort((a, b) => queryHashes.indexOf(getStringHash(b.mes)) - queryHashes.indexOf(getStringHash(a.mes)));
 
-        // Construct the final rearranged chat
-        const rearrangedChat = [...remainingMessages, ...queriedMessages, ...lastNMessages];
-
-        if (rearrangedChat.length !== chat.length) {
-            console.error('Vectors: Rearranged chat length does not match original chat length! This should not happen.');
-            return;
+        // Remove queried messages from the original chat array
+        for (const message of chat) {
+            if (queriedMessages.includes(message)) {
+                chat.splice(chat.indexOf(message), 1);
+            }
         }
 
-        // Update the original chat array in-place
-        chat.splice(0, chat.length, ...rearrangedChat);
+        // Format queried messages into a single string
+        const queriedText = 'Past events: ' + queriedMessages.map(x => collapseNewlines(`${x.name}: ${x.mes}`).trim()).join('\n\n');
+        setExtensionPrompt(EXTENSION_PROMPT_TAG, queriedText, extension_prompt_types.IN_PROMPT, 0);
     } catch (error) {
         console.error('Vectors: Failed to rearrange chat', error);
     }
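Put together, the rewritten body no longer reorders the chat: it pulls the retrieved messages out of the chat and hands them to the extension prompt instead. A condensed, self-contained sketch of that flow, with getStringHash, queryCollection and setExtensionPrompt stubbed or omitted and names following the diff:

const AMOUNT_TO_LEAVE = 5;

// Returns the formatted memory text and removes the matched messages from `chat` in place.
function extractMemories(chat, queryHashes, getHash) {
    const retain = chat.slice(-AMOUNT_TO_LEAVE);
    const queried = chat.filter(m => !retain.includes(m) && m.mes && queryHashes.includes(getHash(m.mes)));

    // More relevant hashes sit at lower indices, so reverse the order for the prompt.
    queried.sort((a, b) => queryHashes.indexOf(getHash(b.mes)) - queryHashes.indexOf(getHash(a.mes)));

    // The queried messages leave the chat; they will live in the extension block instead.
    for (const m of queried) {
        chat.splice(chat.indexOf(m), 1);
    }

    return 'Past events: ' + queried.map(m => `${m.name}: ${m.mes}`.trim()).join('\n\n');
}

// The result would then be passed to setExtensionPrompt(EXTENSION_PROMPT_TAG, text, IN_PROMPT, 0).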
@@ -151,6 +139,11 @@ window['vectors_rearrangeChat'] = rearrangeChat;
 
 const onChatEvent = debounce(async () => await moduleWorker.update(), 500);
 
+/**
+ * Gets the text to query from the chat
+ * @param {object[]} chat Chat messages
+ * @returns {string} Text to query
+ */
 function getQueryText(chat) {
     let queryText = '';
     let i = 0;
@@ -161,7 +154,7 @@ function getQueryText(chat) {
             i++;
         }
 
-        if (i === QUERY_AMOUNT) {
+        if (i === QUERY_TEXT_AMOUNT) {
            break;
        }
    }
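The loop shown in these two hunks builds the query from the most recent messages and stops after QUERY_TEXT_AMOUNT of them. A hedged reconstruction of that helper; the exact concatenation details are not visible in the diff and are assumed here:

const QUERY_TEXT_AMOUNT = 3;

// Assumed shape of getQueryText(): walk the chat from the end and join recent message texts.
function getQueryText(chat) {
    let queryText = '';
    let i = 0;
    for (const message of chat.slice().reverse()) {
        if (message.mes) {
            queryText = message.mes + '\n' + queryText;
            i++;
        }
        if (i === QUERY_TEXT_AMOUNT) {
            break;
        }
    }
    return queryText.trim();
}

console.log(getQueryText([{ mes: 'a' }, { mes: 'b' }, { mes: 'c' }, { mes: 'd' }])); // "b\nc\nd"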
@@ -228,13 +221,14 @@ async function deleteVectorItems(collectionId, hashes) {
 /**
  * @param {string} collectionId - The collection to query
  * @param {string} searchText - The text to query
+ * @param {number} topK - The number of results to return
  * @returns {Promise<number[]>} - Hashes of the results
  */
-async function queryCollection(collectionId, searchText) {
+async function queryCollection(collectionId, searchText, topK) {
    const response = await fetch('/api/vector/query', {
        method: 'POST',
        headers: getRequestHeaders(),
-       body: JSON.stringify({ collectionId, searchText }),
+       body: JSON.stringify({ collectionId, searchText, topK }),
    });
 
    if (!response.ok) {
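The client now tells the backend how many results it wants. A sketch of the request body shape sent to /api/vector/query after this change; the collection ID and search text are invented for illustration:

// Example request body now accepted by /api/vector/query.
const body = JSON.stringify({
    collectionId: 'chat-20240101-123456',
    searchText: 'Last three messages concatenated into one query string',
    topK: 3, // INSERT_AMOUNT on the client; defaults to 10 on the server when omitted (see the endpoint hunk below)
});
console.log(body);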
@@ -696,6 +696,14 @@ function preparePromptsForChatCompletion({Scenario, charPersonality, name2, worl
         identifier: 'authorsNote'
     });
 
+    // Vectors Memory
+    const vectorsMemory = extensionPrompts['3_vectors'];
+    if (vectorsMemory && vectorsMemory.value) systemPrompts.push({
+        role: 'system',
+        content: vectorsMemory.value,
+        identifier: 'vectorsMemory',
+    });
+
     // Persona Description
     if (power_user.persona_description && power_user.persona_description_position === persona_description_positions.IN_PROMPT) {
         systemPrompts.push({ role: 'system', content: power_user.persona_description, identifier: 'personaDescription' });
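In the Chat Completion path, the '3_vectors' extension prompt is surfaced as its own system prompt. A stand-alone sketch of the guard and the object that gets pushed; the surrounding prompt list and the sample memory value are hypothetical:

// extensionPrompts mirrors what setExtensionPrompt() stores; the sample value is invented.
const extensionPrompts = {
    '3_vectors': { value: 'Past events: Alice: we met at the tavern...', position: 0, depth: 0 },
};
const systemPrompts = [];

const vectorsMemory = extensionPrompts['3_vectors'];
if (vectorsMemory && vectorsMemory.value) {
    systemPrompts.push({ role: 'system', content: vectorsMemory.value, identifier: 'vectorsMemory' });
}

console.log(systemPrompts.length); // 1 when the vectors extension has produced a memory block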
@@ -23,10 +23,6 @@ class EmbeddingModel {
     }
 }
 
-/**
- * Hard limit on the number of results to return from the vector search.
- */
-const TOP_K = 100;
 const model = new EmbeddingModel();
 
 /**
@@ -100,17 +96,18 @@ async function deleteVectorItems(collectionId, hashes) {
 
 /**
  * Gets the hashes of the items in the vector collection that match the search text
- * @param {string} collectionId
- * @param {string} searchText
+ * @param {string} collectionId - The collection ID
+ * @param {string} searchText - The text to search for
+ * @param {number} topK - The number of results to return
  * @returns {Promise<number[]>} - The hashes of the items that match the search text
  */
-async function queryCollection(collectionId, searchText) {
+async function queryCollection(collectionId, searchText, topK) {
     const index = await getIndex(collectionId);
     const use = await model.get();
     const tensor = await use.embed(searchText);
     const vector = Array.from(await tensor.data());
 
-    const result = await index.queryItems(vector, TOP_K);
+    const result = await index.queryItems(vector, topK);
     const hashes = result.map(x => Number(x.item.metadata.hash));
     return hashes;
 }
@@ -129,8 +126,9 @@ async function registerEndpoints(app, jsonParser) {
 
             const collectionId = String(req.body.collectionId);
             const searchText = String(req.body.searchText);
+            const topK = Number(req.body.topK) || 10;
 
-            const results = await queryCollection(collectionId, searchText);
+            const results = await queryCollection(collectionId, searchText, topK);
             return res.json(results);
         } catch (error) {
             console.error(error);
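Server-side, topK falls back to 10 when the client omits it, because Number(undefined) is NaN and NaN || 10 evaluates to 10. A quick check of that fallback, applied to invented request bodies:

// Same parsing expression as the endpoint above, exercised against hypothetical payloads.
const parseTopK = (body) => Number(body.topK) || 10;

console.log(parseTopK({ topK: 3 })); // 3
console.log(parseTopK({}));          // 10 — fallback when the field is missing
console.log(parseTopK({ topK: 0 })); // 10 — note: an explicit 0 also falls back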