diff --git a/public/script.js b/public/script.js index c660a995f..2ab1f0e15 100644 --- a/public/script.js +++ b/public/script.js @@ -118,7 +118,7 @@ import { isElementInViewport, } from "./scripts/utils.js"; -import { extension_settings, loadExtensionSettings } from "./scripts/extensions.js"; +import { extension_settings, loadExtensionSettings, runGenerationInterceptors } from "./scripts/extensions.js"; import { executeSlashCommands, getSlashCommandsHelp, registerSlashCommand } from "./scripts/slash-commands.js"; import { tag_map, @@ -180,6 +180,7 @@ export { getStoppingStrings, getStatus, reloadMarkdownProcessor, + getCurrentChatId, chat, this_chid, selected_button, @@ -506,6 +507,15 @@ function reloadMarkdownProcessor(render_formulas = false) { return converter; } +function getCurrentChatId() { + if (selected_group) { + return groups.find(x => x.id == selected_group)?.chat_id; + } + else if (this_chid) { + return characters[this_chid].chat; + } +} + const CHARACTERS_PER_TOKEN_RATIO = 3.35; const talkativeness_default = 0.5; @@ -1812,6 +1822,8 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject, if (type === 'swipe') { coreChat.pop(); } + + await runGenerationInterceptors(coreChat); console.log(`Core/all messages: ${coreChat.length}/${chat.length}`); if (main_api === 'openai') { diff --git a/public/scripts/extensions.js b/public/scripts/extensions.js index fc06649c9..2ce4d4b4c 100644 --- a/public/scripts/extensions.js +++ b/public/scripts/extensions.js @@ -4,6 +4,7 @@ export { getContext, getApiUrl, loadExtensionSettings, + runGenerationInterceptors, defaultRequestArgs, modules, extension_settings, @@ -26,6 +27,7 @@ const extension_settings = { dice: {}, tts: {}, sd: {}, + chromadb: {}, }; let modules = []; @@ -316,6 +318,19 @@ async function loadExtensionSettings(settings) { } } +async function runGenerationInterceptors(chat) { + for (const manifest of Object.values(manifests)) { + const interceptorKey = manifest.generate_interceptor; + if (typeof window[interceptorKey] === 'function') { + try { + await window[interceptorKey](chat); + } catch(e) { + console.error(`Failed running interceptor for ${manifest.display_name}`, e); + } + } + } +} + $(document).ready(async function () { setTimeout(function () { addExtensionsButtonAndMenu(); diff --git a/public/scripts/extensions/infinity-context/index.js b/public/scripts/extensions/infinity-context/index.js new file mode 100644 index 000000000..f813e70b5 --- /dev/null +++ b/public/scripts/extensions/infinity-context/index.js @@ -0,0 +1,168 @@ +import { saveSettingsDebounced, getCurrentChatId } from "../../../script.js"; +import { getApiUrl, extension_settings } from "../../extensions.js"; +import { splitRecursive } from "../../utils.js"; +export { MODULE_NAME }; + +const MODULE_NAME = 'chromadb'; + +const defaultSettings = { + keep_context: 10, + keep_context_min: 1, + keep_context_max: 100, + keep_context_step: 1, + + n_results: 20, + n_results_min: 1, + n_results_max: 100, + n_results_step: 1, + + split_length: 384, + split_length_min: 64, + split_length_max: 4096, + split_length_step: 64, +}; + +const postHeaders = { + 'Content-Type': 'application/json', + 'Bypass-Tunnel-Reminder': 'bypass', +}; + +async function loadSettings() { + if (Object.keys(extension_settings.chromadb).length === 0) { + Object.assign(extension_settings.chromadb, defaultSettings); + } + + $('#chromadb_keep_context').val(extension_settings.chromadb.keep_context).trigger('input'); + $('#chromadb_n_results').val(extension_settings.chromadb.n_results).trigger('input'); + $('#chromadb_split_length').val(extension_settings.chromadb.split_length).trigger('input'); +} + +function onKeepContextInput() { + extension_settings.chromadb.keep_context = Number($('#chromadb_keep_context').val()); + $('#chromadb_keep_context_value').text(extension_settings.chromadb.keep_context); + saveSettingsDebounced(); +} + +function onNResultsInput() { + extension_settings.chromadb.n_results = Number($('#chromadb_n_results').val()); + $('#chromadb_n_results_value').text(extension_settings.chromadb.n_results); + saveSettingsDebounced(); +} + +function onSplitLengthInput() { + extension_settings.chromadb.split_length = Number($('#chromadb_split_length').val()); + $('#chromadb_split_length_value').text(extension_settings.chromadb.split_length); + saveSettingsDebounced(); +} + +async function addMessages(chat_id, messages) { + const url = new URL(getApiUrl()); + url.pathname = '/api/chromadb'; + + const messagesDeepCopy = JSON.parse(JSON.stringify(messages)); + const splittedMessages = []; + + let id = 0; + messagesDeepCopy.forEach(m => { + const split = splitRecursive(m.mes, extension_settings.chromadb.split_length); + splittedMessages.push(...split.map(text => ({ + ...m, + mes: text, + send_date: id, + id: `msg-${id++}`, + }))); + }); + + const transformedMessages = splittedMessages.map((m) => ({ + id: m.id, + role: m.is_user ? 'user' : 'assistant', + content: m.mes, + date: m.send_date, + meta: JSON.stringify(m), + })); + + const addMessagesResult = await fetch(url, { + method: 'POST', + headers: postHeaders, + body: JSON.stringify({ chat_id, messages: transformedMessages }), + }); + + if (addMessagesResult.ok) { + const addMessagesData = await addMessagesResult.json(); + + return addMessagesData; // { count: 1 } + } + + return { count: 0 }; +} + +async function queryMessages(chat_id, query) { + const url = new URL(getApiUrl()); + url.pathname = '/api/chromadb/query'; + + const queryMessagesResult = await fetch(url, { + method: 'POST', + headers: postHeaders, + body: JSON.stringify({ chat_id, query, n_results: extension_settings.chromadb.n_results }), + }); + + if (queryMessagesResult.ok) { + const queryMessagesData = await queryMessagesResult.json(); + + return queryMessagesData; + } + + return []; +} + +window.chromadb_interceptGeneration = async (chat) => { + const currentChatId = getCurrentChatId(); + + if (currentChatId) { + const messagesToStore = chat.slice(0, -extension_settings.chromadb.keep_context); + + if (messagesToStore.length > 0) { + await addMessages(currentChatId, messagesToStore); + + const lastMessage = chat[chat.length - 1]; + + if (lastMessage) { + const queriedMessages = await queryMessages(currentChatId, lastMessage.mes); + + queriedMessages.sort((a, b) => a.date - b.date); + + const newChat = queriedMessages.map(m => JSON.parse(m.meta)); + + chat.splice(0, messagesToStore.length, ...newChat); + + console.log('ChromaDB chat after injection', chat); + } + } + } +} + +jQuery(async () => { + const settingsHtml = ` +
+
+
+ Infinity Context +
+
+
+ + + + + + +
+
`; + + $('#extensions_settings').append(settingsHtml); + $('#chromadb_keep_context').on('input', onKeepContextInput); + $('#chromadb_n_results').on('input', onNResultsInput); + $('#chromadb_split_length').on('input', onSplitLengthInput); + + await loadSettings(); +}); diff --git a/public/scripts/extensions/infinity-context/manifest.json b/public/scripts/extensions/infinity-context/manifest.json new file mode 100644 index 000000000..463e68199 --- /dev/null +++ b/public/scripts/extensions/infinity-context/manifest.json @@ -0,0 +1,14 @@ +{ + "display_name": "Infinity Context", + "loading_order": 11, + "requires": [ + "chromadb" + ], + "optional": [], + "generate_interceptor": "chromadb_interceptGeneration", + "js": "index.js", + "css": "style.css", + "author": "maceter636@proton.me", + "version": "1.0.0", + "homePage": "https://github.com/Cohee1207/SillyTavern" +} \ No newline at end of file diff --git a/public/scripts/extensions/infinity-context/style.css b/public/scripts/extensions/infinity-context/style.css new file mode 100644 index 000000000..e69de29bb diff --git a/public/scripts/utils.js b/public/scripts/utils.js index e0cb97f72..1bf9a15ff 100644 --- a/public/scripts/utils.js +++ b/public/scripts/utils.js @@ -243,5 +243,36 @@ export function countOccurrences(string, character) { } export function isOdd(number) { - return number % 2 !== 0; -} \ No newline at end of file + return number % 2 !== 0; +} + +/** Split string to parts no more than length in size */ +export function splitRecursive(input, length, delimitiers = ['\n\n', '\n', ' ', '']) { + const delim = delimitiers[0] ?? ''; + const parts = input.split(delim); + + const flatParts = parts.flatMap(p => { + if (p.length < length) return p; + return splitRecursive(input, length, delimitiers.slice(1)); + }); + + // Merge short chunks + const result = []; + let currentChunk = ''; + for (let i = 0; i < flatParts.length;) { + currentChunk = flatParts[i]; + let j = i + 1; + while (j < flatParts.length) { + const nextChunk = flatParts[j]; + if (currentChunk.length + nextChunk.length + delim.length <= length) { + currentChunk += delim + nextChunk; + } else { + break; + } + j++; + } + i = j; + result.push(currentChunk); + } + return result; +}