From 9d45c0a01813abda665a3a166a289aac2199d38a Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 8 Sep 2023 00:28:06 +0300 Subject: [PATCH] Add UI plugin for vectors --- public/scripts/extensions.js | 2 + public/scripts/extensions/vectors/index.js | 267 ++++++++++++++++++ .../scripts/extensions/vectors/manifest.json | 12 + .../scripts/extensions/vectors/settings.html | 14 + src/vectors.js | 6 +- 5 files changed, 298 insertions(+), 3 deletions(-) create mode 100644 public/scripts/extensions/vectors/index.js create mode 100644 public/scripts/extensions/vectors/manifest.json create mode 100644 public/scripts/extensions/vectors/settings.html diff --git a/public/scripts/extensions.js b/public/scripts/extensions.js index 0d0018d17..5e633f17b 100644 --- a/public/scripts/extensions.js +++ b/public/scripts/extensions.js @@ -153,6 +153,8 @@ const extension_settings = { }, speech_recognition: {}, rvc: {}, + hypebot: {}, + vectors: {}, }; let modules = []; diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js new file mode 100644 index 000000000..3b63bf395 --- /dev/null +++ b/public/scripts/extensions/vectors/index.js @@ -0,0 +1,267 @@ +import { eventSource, event_types, getCurrentChatId, getRequestHeaders, saveSettingsDebounced } from "../../../script.js"; +import { ModuleWorkerWrapper, extension_settings, getContext, renderExtensionTemplate } from "../../extensions.js"; +import { collapseNewlines } from "../../power-user.js"; +import { debounce, getStringHash as calculateHash } from "../../utils.js"; + +const MODULE_NAME = 'vectors'; +const MIN_TO_LEAVE = 5; +const QUERY_AMOUNT = 2; +const LEAVE_RATIO = 0.5; + +const settings = { + enabled: false, +}; + +const moduleWorker = new ModuleWorkerWrapper(synchronizeChat); + +async function synchronizeChat() { + try { + if (!settings.enabled) { + return; + } + + const context = getContext(); + const chatId = getCurrentChatId(); + + if (!chatId || !Array.isArray(context.chat)) { + console.debug('Vectors: No chat selected'); + return; + } + + const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(x.mes), hash: getStringHash(x.mes) })); + const hashesInCollection = await getSavedHashes(chatId); + + const newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash)); + const deletedHashes = hashesInCollection.filter(x => !hashedMessages.some(y => y.hash === x)); + + if (newVectorItems.length > 0) { + await insertVectorItems(chatId, newVectorItems); + console.log(`Vectors: Inserted ${newVectorItems.length} new items`); + } + + if (deletedHashes.length > 0) { + await deleteVectorItems(chatId, deletedHashes); + console.log(`Vectors: Deleted ${deletedHashes.length} old hashes`); + } + } catch (error) { + console.error('Vectors: Failed to synchronize chat', error); + } +} + +// Cache object for storing hash values +const hashCache = {}; + +/** + * Gets the hash value for a given string + * @param {string} str Input string + * @returns {number} Hash value + */ +function getStringHash(str) { + // Check if the hash is already in the cache + if (hashCache.hasOwnProperty(str)) { + return hashCache[str]; + } + + // Calculate the hash value + const hash = calculateHash(str); + + // Store the hash in the cache + hashCache[str] = hash; + + return hash; +} + +/** + * Rearranges the chat based on the relevance of recent messages + * @param {object[]} chat Array of chat messages + */ +async function rearrangeChat(chat) { + try { + if (!settings.enabled) { + return; + } + + const chatId = getCurrentChatId(); + + if (!chatId || !Array.isArray(chat)) { + console.debug('Vectors: No chat selected'); + return; + } + + if (chat.length < MIN_TO_LEAVE) { + console.debug(`Vectors: Not enough messages to rearrange (less than ${MIN_TO_LEAVE})`); + return; + } + + const queryText = getQueryText(chat); + + if (queryText.length === 0) { + console.debug('Vectors: No text to query'); + return; + } + + const queryHashes = await queryCollection(chatId, queryText); + + // Sorting logic + // 1. 50% of messages at the end stay in the same place (minimum 5) + // 2. Messages that are in the query are rearranged to match the query order + // 3. Messages that are not in the query and are not in the top 50% stay in the same place + const queriedMessages = []; + const remainingMessages = []; + + // Leave the last N messages intact + const retainMessagesCount = Math.max(Math.floor(chat.length * LEAVE_RATIO), MIN_TO_LEAVE); + const lastNMessages = chat.slice(-retainMessagesCount); + + // Splitting messages into queried and remaining messages + for (const message of chat) { + if (lastNMessages.includes(message)) { + continue; + } + + if (message.mes && queryHashes.includes(getStringHash(message.mes))) { + queriedMessages.push(message); + } else { + remainingMessages.push(message); + } + } + + // Rearrange queried messages to match query order + // Order is reversed because more relevant are at the lower indices + queriedMessages.sort((a, b) => { + return queryHashes.indexOf(getStringHash(b.mes)) - queryHashes.indexOf(getStringHash(a.mes)); + }); + + // Construct the final rearranged chat + const rearrangedChat = [...remainingMessages, ...queriedMessages, ...lastNMessages]; + + if (rearrangedChat.length !== chat.length) { + console.error('Vectors: Rearranged chat length does not match original chat length! This should not happen.'); + return; + } + + // Update the original chat array in-place + chat.splice(0, chat.length, ...rearrangedChat); + } catch (error) { + console.error('Vectors: Failed to rearrange chat', error); + } +} + +window['vectors_rearrangeChat'] = rearrangeChat; + +const onChatEvent = debounce(async () => await moduleWorker.update(), 500); + +function getQueryText(chat) { + let queryText = ''; + let i = 0; + + for (const message of chat.slice().reverse()) { + if (message.mes) { + queryText += message.mes + '\n'; + i++; + } + + if (i === QUERY_AMOUNT) { + break; + } + } + + return collapseNewlines(queryText).trim(); +} + +/** + * Gets the saved hashes for a collection +* @param {string} collectionId +* @returns {Promise} Saved hashes +*/ +async function getSavedHashes(collectionId) { + const response = await fetch('/api/vector/list', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ collectionId }), + }); + + if (!response.ok) { + throw new Error(`Failed to get saved hashes for collection ${collectionId}`); + } + + const hashes = await response.json(); + return hashes; +} + +/** + * Inserts vector items into a collection + * @param {string} collectionId - The collection to insert into + * @param {{ hash: number, text: string }[]} items - The items to insert + * @returns {Promise} + */ +async function insertVectorItems(collectionId, items) { + const response = await fetch('/api/vector/insert', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ collectionId, items }), + }); + + if (!response.ok) { + throw new Error(`Failed to insert vector items for collection ${collectionId}`); + } +} + +/** + * Deletes vector items from a collection + * @param {string} collectionId - The collection to delete from + * @param {number[]} hashes - The hashes of the items to delete + * @returns {Promise} + */ +async function deleteVectorItems(collectionId, hashes) { + const response = await fetch('/api/vector/delete', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ collectionId, hashes }), + }); + + if (!response.ok) { + throw new Error(`Failed to delete vector items for collection ${collectionId}`); + } +} + +/** + * @param {string} collectionId - The collection to query + * @param {string} searchText - The text to query + * @returns {Promise} - Hashes of the results + */ +async function queryCollection(collectionId, searchText) { + const response = await fetch('/api/vector/query', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ collectionId, searchText }), + }); + + if (!response.ok) { + throw new Error(`Failed to query collection ${collectionId}`); + } + + const results = await response.json(); + return results; +} + +jQuery(async () => { + if (!extension_settings.vectors) { + extension_settings.vectors = settings; + } + + Object.assign(settings, extension_settings.vectors); + $('#extensions_settings2').append(renderExtensionTemplate(MODULE_NAME, 'settings')); + $('#vectors_enabled').prop('checked', settings.enabled).on('input', () => { + settings.enabled = $('#vectors_enabled').prop('checked'); + Object.assign(extension_settings.vectors, settings); + saveSettingsDebounced(); + }); + + eventSource.on(event_types.CHAT_CHANGED, onChatEvent); + eventSource.on(event_types.MESSAGE_DELETED, onChatEvent); + eventSource.on(event_types.MESSAGE_EDITED, onChatEvent); + eventSource.on(event_types.MESSAGE_SENT, onChatEvent); + eventSource.on(event_types.MESSAGE_RECEIVED, onChatEvent); + eventSource.on(event_types.MESSAGE_SWIPED, onChatEvent); +}); diff --git a/public/scripts/extensions/vectors/manifest.json b/public/scripts/extensions/vectors/manifest.json new file mode 100644 index 000000000..7f84c2147 --- /dev/null +++ b/public/scripts/extensions/vectors/manifest.json @@ -0,0 +1,12 @@ +{ + "display_name": "Vector Storage", + "loading_order": 100, + "requires": [], + "optional": [], + "generate_interceptor": "vectors_rearrangeChat", + "js": "index.js", + "css": "", + "author": "Cohee#1207", + "version": "1.0.0", + "homePage": "https://github.com/SillyTavern/SillyTavern" +} diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html new file mode 100644 index 000000000..d0e0b6294 --- /dev/null +++ b/public/scripts/extensions/vectors/settings.html @@ -0,0 +1,14 @@ +
+
+
+ Vector Storage +
+
+
+ +
+
+
diff --git a/src/vectors.js b/src/vectors.js index f09444ec9..81c2a4161 100644 --- a/src/vectors.js +++ b/src/vectors.js @@ -12,14 +12,14 @@ class EmbeddingModel { /** * @type {encoder.UniversalSentenceEncoder} - The embedding model */ - #model; + model; async get() { - if (!this.#model) { + if (!this.model) { this.model = await encoder.load(); } - return this.#model; + return this.model; } }