diff --git a/public/script.js b/public/script.js index b8292dc73..e9945ee9a 100644 --- a/public/script.js +++ b/public/script.js @@ -382,10 +382,7 @@ const system_message_types = { }; const extension_prompt_types = { - /** - * @deprecated Outdated term. In reality it's "after main prompt or story string" - */ - AFTER_SCENARIO: 0, + IN_PROMPT: 0, IN_CHAT: 1 }; @@ -2533,7 +2530,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject, addPersonaDescriptionExtensionPrompt(); // Call combined AN into Generate let allAnchors = getAllExtensionPrompts(); - const afterScenarioAnchor = getExtensionPrompt(extension_prompt_types.AFTER_SCENARIO); + const afterScenarioAnchor = getExtensionPrompt(extension_prompt_types.IN_PROMPT); let zeroDepthAnchor = getExtensionPrompt(extension_prompt_types.IN_CHAT, 0, ' '); const storyStringParams = { @@ -5591,7 +5588,7 @@ function select_rm_characters() { * @param {number} position Insertion position. 0 is after story string, 1 is in-chat with custom depth. * @param {number} depth Insertion depth. 0 represets the last message in context. Expected values up to 100. */ -function setExtensionPrompt(key, value, position, depth) { +export function setExtensionPrompt(key, value, position, depth) { extension_prompts[key] = { value: String(value), position: Number(position), depth: Number(depth) }; } diff --git a/public/scripts/extensions/infinity-context/index.js b/public/scripts/extensions/infinity-context/index.js index b06309e5b..303b2eee0 100644 --- a/public/scripts/extensions/infinity-context/index.js +++ b/public/scripts/extensions/infinity-context/index.js @@ -739,7 +739,7 @@ window.chromadb_interceptGeneration = async (chat, maxContext) => { // No memories? No prompt. const promptBlob = (tokenApprox == 0) ? "" : wrapperMsg.replace('{{memories}}', allMemoryBlob); console.debug("CHROMADB: prompt blob: %o", promptBlob); - context.setExtensionPrompt(MODULE_NAME, promptBlob, extension_prompt_types.AFTER_SCENARIO); + context.setExtensionPrompt(MODULE_NAME, promptBlob, extension_prompt_types.IN_PROMPT); } if (selectedStrategy === 'custom') { const context = getContext(); diff --git a/public/scripts/extensions/memory/index.js b/public/scripts/extensions/memory/index.js index 8dcf841d5..d2eaebf4f 100644 --- a/public/scripts/extensions/memory/index.js +++ b/public/scripts/extensions/memory/index.js @@ -63,7 +63,7 @@ const defaultSettings = { source: summary_sources.extras, prompt: defaultPrompt, template: defaultTemplate, - position: extension_prompt_types.AFTER_SCENARIO, + position: extension_prompt_types.IN_PROMPT, depth: 2, promptWords: 200, promptMinWords: 25, diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js index 3b63bf395..cf78cf7ee 100644 --- a/public/scripts/extensions/vectors/index.js +++ b/public/scripts/extensions/vectors/index.js @@ -1,12 +1,14 @@ -import { eventSource, event_types, getCurrentChatId, getRequestHeaders, saveSettingsDebounced } from "../../../script.js"; +import { eventSource, event_types, extension_prompt_types, getCurrentChatId, getRequestHeaders, saveSettingsDebounced, setExtensionPrompt } from "../../../script.js"; import { ModuleWorkerWrapper, extension_settings, getContext, renderExtensionTemplate } from "../../extensions.js"; import { collapseNewlines } from "../../power-user.js"; import { debounce, getStringHash as calculateHash } from "../../utils.js"; const MODULE_NAME = 'vectors'; -const MIN_TO_LEAVE = 5; -const QUERY_AMOUNT = 2; -const LEAVE_RATIO = 0.5; +const AMOUNT_TO_LEAVE = 5; +const INSERT_AMOUNT = 3; +const QUERY_TEXT_AMOUNT = 3; + +export const EXTENSION_PROMPT_TAG = '3_vectors'; const settings = { enabled: false, @@ -72,7 +74,7 @@ function getStringHash(str) { } /** - * Rearranges the chat based on the relevance of recent messages + * Removes the most relevant messages from the chat and displays them in the extension prompt * @param {object[]} chat Array of chat messages */ async function rearrangeChat(chat) { @@ -88,8 +90,8 @@ async function rearrangeChat(chat) { return; } - if (chat.length < MIN_TO_LEAVE) { - console.debug(`Vectors: Not enough messages to rearrange (less than ${MIN_TO_LEAVE})`); + if (chat.length < AMOUNT_TO_LEAVE) { + console.debug(`Vectors: Not enough messages to rearrange (less than ${AMOUNT_TO_LEAVE})`); return; } @@ -100,48 +102,34 @@ async function rearrangeChat(chat) { return; } - const queryHashes = await queryCollection(chatId, queryText); - - // Sorting logic - // 1. 50% of messages at the end stay in the same place (minimum 5) - // 2. Messages that are in the query are rearranged to match the query order - // 3. Messages that are not in the query and are not in the top 50% stay in the same place + // Get the most relevant messages, excluding the last few + const queryHashes = await queryCollection(chatId, queryText, INSERT_AMOUNT); const queriedMessages = []; - const remainingMessages = []; + const retainMessages = chat.slice(-AMOUNT_TO_LEAVE); - // Leave the last N messages intact - const retainMessagesCount = Math.max(Math.floor(chat.length * LEAVE_RATIO), MIN_TO_LEAVE); - const lastNMessages = chat.slice(-retainMessagesCount); - - // Splitting messages into queried and remaining messages for (const message of chat) { - if (lastNMessages.includes(message)) { + if (retainMessages.includes(message)) { continue; } - if (message.mes && queryHashes.includes(getStringHash(message.mes))) { queriedMessages.push(message); - } else { - remainingMessages.push(message); } } // Rearrange queried messages to match query order // Order is reversed because more relevant are at the lower indices - queriedMessages.sort((a, b) => { - return queryHashes.indexOf(getStringHash(b.mes)) - queryHashes.indexOf(getStringHash(a.mes)); - }); + queriedMessages.sort((a, b) => queryHashes.indexOf(getStringHash(b.mes)) - queryHashes.indexOf(getStringHash(a.mes))); - // Construct the final rearranged chat - const rearrangedChat = [...remainingMessages, ...queriedMessages, ...lastNMessages]; - - if (rearrangedChat.length !== chat.length) { - console.error('Vectors: Rearranged chat length does not match original chat length! This should not happen.'); - return; + // Remove queried messages from the original chat array + for (const message of chat) { + if (queriedMessages.includes(message)) { + chat.splice(chat.indexOf(message), 1); + } } - // Update the original chat array in-place - chat.splice(0, chat.length, ...rearrangedChat); + // Format queried messages into a single string + const queriedText = 'Past events: ' + queriedMessages.map(x => collapseNewlines(`${x.name}: ${x.mes}`).trim()).join('\n\n'); + setExtensionPrompt(EXTENSION_PROMPT_TAG, queriedText, extension_prompt_types.IN_PROMPT, 0); } catch (error) { console.error('Vectors: Failed to rearrange chat', error); } @@ -151,6 +139,11 @@ window['vectors_rearrangeChat'] = rearrangeChat; const onChatEvent = debounce(async () => await moduleWorker.update(), 500); +/** + * Gets the text to query from the chat + * @param {object[]} chat Chat messages + * @returns {string} Text to query + */ function getQueryText(chat) { let queryText = ''; let i = 0; @@ -161,7 +154,7 @@ function getQueryText(chat) { i++; } - if (i === QUERY_AMOUNT) { + if (i === QUERY_TEXT_AMOUNT) { break; } } @@ -228,13 +221,14 @@ async function deleteVectorItems(collectionId, hashes) { /** * @param {string} collectionId - The collection to query * @param {string} searchText - The text to query + * @param {number} topK - The number of results to return * @returns {Promise} - Hashes of the results */ -async function queryCollection(collectionId, searchText) { +async function queryCollection(collectionId, searchText, topK) { const response = await fetch('/api/vector/query', { method: 'POST', headers: getRequestHeaders(), - body: JSON.stringify({ collectionId, searchText }), + body: JSON.stringify({ collectionId, searchText, topK }), }); if (!response.ok) { diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 3ac66f0df..ce265fd41 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -696,6 +696,14 @@ function preparePromptsForChatCompletion({Scenario, charPersonality, name2, worl identifier: 'authorsNote' }); + // Vectors Memory + const vectorsMemory = extensionPrompts['3_vectors']; + if (vectorsMemory && vectorsMemory.value) systemPrompts.push({ + role: 'system', + content: vectorsMemory.value, + identifier: 'vectorsMemory', + }); + // Persona Description if (power_user.persona_description && power_user.persona_description_position === persona_description_positions.IN_PROMPT) { systemPrompts.push({ role: 'system', content: power_user.persona_description, identifier: 'personaDescription' }); diff --git a/src/vectors.js b/src/vectors.js index 81c2a4161..26265bcb0 100644 --- a/src/vectors.js +++ b/src/vectors.js @@ -23,10 +23,6 @@ class EmbeddingModel { } } -/** - * Hard limit on the number of results to return from the vector search. - */ -const TOP_K = 100; const model = new EmbeddingModel(); /** @@ -100,17 +96,18 @@ async function deleteVectorItems(collectionId, hashes) { /** * Gets the hashes of the items in the vector collection that match the search text - * @param {string} collectionId - * @param {string} searchText + * @param {string} collectionId - The collection ID + * @param {string} searchText - The text to search for + * @param {number} topK - The number of results to return * @returns {Promise} - The hashes of the items that match the search text */ -async function queryCollection(collectionId, searchText) { +async function queryCollection(collectionId, searchText, topK) { const index = await getIndex(collectionId); const use = await model.get(); const tensor = await use.embed(searchText); const vector = Array.from(await tensor.data()); - const result = await index.queryItems(vector, TOP_K); + const result = await index.queryItems(vector, topK); const hashes = result.map(x => Number(x.item.metadata.hash)); return hashes; } @@ -129,8 +126,9 @@ async function registerEndpoints(app, jsonParser) { const collectionId = String(req.body.collectionId); const searchText = String(req.body.searchText); + const topK = Number(req.body.topK) || 10; - const results = await queryCollection(collectionId, searchText); + const results = await queryCollection(collectionId, searchText, topK); return res.json(results); } catch (error) { console.error(error);