From 18641ea3d2278637ea9a0efd788abf63e568575b Mon Sep 17 00:00:00 2001 From: maver Date: Tue, 13 Jun 2023 20:48:06 +0200 Subject: [PATCH] Add prototype for prompt manager token management --- public/scripts/PromptManager.js | 109 +------------------ public/scripts/openai.js | 182 +++++++++++++++++++++++++++++--- 2 files changed, 171 insertions(+), 120 deletions(-) diff --git a/public/scripts/PromptManager.js b/public/scripts/PromptManager.js index 7de80214c..70c4e093a 100644 --- a/public/scripts/PromptManager.js +++ b/public/scripts/PromptManager.js @@ -3,102 +3,6 @@ import {DraggablePromptListModule as DraggableList} from "./DraggableList.js"; import {eventSource, substituteParams} from "../script.js"; import {TokenHandler} from "./openai.js"; -// Thrown by ChatCompletion when a requested prompt couldn't be found. -class IdentifierNotFoundError extends Error { - constructor(identifier) { - super(`Identifier ${identifier} not found`); - this.name = 'IdentifierNotFoundError'; - } -} - -/** - * OpenAI API chat completion representation - * const map = [{identifier: 'example', message: {role: 'system', content: 'exampleContent'}}, ...]; - * - * @see https://platform.openai.com/docs/guides/gpt/chat-completions-api - */ -const ChatCompletion = { - new() { - return { - tokenHandler: null, - map: [], - add(identifier, message) { - this.map.push({identifier, message}); - return this; - }, - get(identifier) { - const index = this.getMessageIndex(identifier); - return this.assertIndex(index, identifier).map[index]; - }, - insertBefore(identifier, insertIdentifier, insert) { - const index = this.getMessageIndex(identifier); - this.map.splice(this.assertIndex(index, identifier), 0, { - identifier: insertIdentifier, - message: insert - }); - return this; - }, - insertAfter(identifier, insertIdentifier, insert) { - const index = this.getMessageIndex(identifier); - this.map.splice(this.assertIndex(index, identifier) + 1, 0, { - identifier: insertIdentifier, - message: insert - }); - return this; - }, - replace(identifier, replacement) { - const index = this.getMessageIndex(identifier); - this.map[this.assertIndex(index, identifier)] = {identifier, message: replacement}; - return this; - }, - remove(identifier) { - const index = this.getMessageIndex(identifier); - this.map.splice(this.assertIndex(index, identifier), 1); - return this; - }, - assertIndex(index, identifier) { - if (index === -1) { - throw new IdentifierNotFoundError(`Identifier ${identifier} not found`); - } - return index; - }, - getMessageIndex(identifier) { - return this.map.findIndex(message => message.identifier === identifier); - }, - makeSystemMessage(content) { - return this.makeMessage('system', content); - }, - makeUserMessage(content) { - return this.makeMessage('user', content); - }, - makeAssistantMessage(content) { - return this.makeMessage('assistant', content); - }, - makeMessage(role, content) { - return {role: role, content: content} - }, - getTokenCounts() { - return this.map.reduce((result, message) => { - result[message.identifier] = message.message ? this.tokenHandler?.count(message.message) : 0; - return result; - }, {}); - }, - getTotalTokenCount() { - return Object.values(this.getTokenCounts()).reduce((accumulator, currentValue) => accumulator + currentValue, 0) - }, - getChat() { - return this.map.reduce((chat, item) => { - if (!item || !item.message || (false === Array.isArray(item.message) && !item.message.content)) return chat; - if (true === Array.isArray(item.message)) { - if (0 !== item.message.length) chat.push(...item.message); - } else chat.push(item.message); - return chat; - }, []); - }, - } - } -}; - function PromptManagerModule() { this.configuration = { prefix: '', @@ -555,18 +459,16 @@ PromptManagerModule.prototype.clearEditForm = function () { * Generates and returns a new ChatCompletion object based on the active character's prompt list. * @returns {Object} A ChatCompletion object */ -PromptManagerModule.prototype.getChatCompletion = function () { - const chatCompletion = ChatCompletion.new(); - chatCompletion.tokenHandler = this.getTokenHandler(); - +PromptManagerModule.prototype.getOrderedPromptList = function () { const promptList = this.getPromptListByCharacter(this.activeCharacter); + const assembledPromptList = []; promptList.forEach(entry => { const chatMessage = this.preparePrompt(this.getPromptById(entry.identifier)) - if (true === entry.enabled) chatCompletion.add(entry.identifier, chatMessage); + if (true === entry.enabled) assembledPromptList.push({identifier: entry.identifier, ...chatMessage}); }) - return chatCompletion; + return assembledPromptList; } // Empties, then re-assembles the container containing the prompt list. @@ -925,6 +827,5 @@ export { PromptManagerModule, openAiDefaultPrompts, openAiDefaultPromptLists, - defaultPromptManagerSettings, - IdentifierNotFoundError + defaultPromptManagerSettings }; diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 3482f16d1..75e41270c 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -27,7 +27,7 @@ import { import {groups, selected_group} from "./group-chats.js"; import { - defaultPromptManagerSettings, IdentifierNotFoundError, + defaultPromptManagerSettings, openAiDefaultPromptLists, openAiDefaultPrompts, PromptManagerModule as PromptManager @@ -66,7 +66,10 @@ export { setOpenAIOnlineStatus, getChatCompletionModel, countTokens, - TokenHandler + TokenHandler, + IdentifierNotFoundError, + Message, + MessageCollection } let openai_msgs = []; @@ -365,20 +368,59 @@ function formatWorldInfo(value) { } async function prepareOpenAIMessages({ name2, charDescription, charPersonality, Scenario, worldInfoBefore, worldInfoAfter, bias, type, quietPrompt, extensionPrompts, cyclePrompt } = {}) { - const chatCompletion = promptManager.getChatCompletion(); + const promptList = promptManager.getOrderedPromptList(); + const promptIndex = (identifier) => promptList.findIndex(prompt => prompt.identifier === identifier); + const getMessageWithIndex = (identifier) => { + const index = promptIndex(identifier); + const prompt = promptList[index]; + const msg = new Message(prompt.role, prompt.content, prompt.identifier); + return {message: msg, index: index} + } - // Prepare messages - const worldInfoBeforeMessage = chatCompletion.makeSystemMessage(formatWorldInfo(worldInfoBefore)); - const worldInfoAfterMessage = chatCompletion.makeSystemMessage(formatWorldInfo(worldInfoAfter)); - const charDescriptionMessage = chatCompletion.makeSystemMessage(substituteParams(charDescription)); + const chatCompletion = ChatCompletion.new(); + chatCompletion.tokenBudget = promptManager.serviceSettings.openai_max_context - promptManager.serviceSettings.amount_gen; - const charPersonalityMessage = chatCompletion.makeSystemMessage( - name2 + 's personality: ' + substituteParams(charPersonality) - ); + const main = getMessageWithIndex('main'); + const nsfw = getMessageWithIndex('nsfw'); + const jailbreak = getMessageWithIndex('jailbreak'); - const scenarioMessage = chatCompletion.makeSystemMessage( - 'Circumstances and context of the dialogue: ' + substituteParams(Scenario) - ); + const worldInfoBeforeMsg = new Message('system', formatWorldInfo(worldInfoBefore), 'worldInfoBefore'); + const worldInfoAfterMsg = new Message('system', formatWorldInfo(worldInfoAfter), 'worldInfoAfter'); + const charDescriptionMsg = new Message('system', substituteParams(charDescription), 'charDescription'); + const charPersonalityMsg = new Message('system', `${name2}'s personality: ${substituteParams(charPersonality)}`, 'charPersonality'); + const scenarioMsg = new Message('system', `Circumstances and context of the dialogue: ${substituteParams(Scenario)}`, 'scenario'); + + chatCompletion + .add(main.message, main.index) + .add(nsfw.message, nsfw.index) + .add(jailbreak.message, jailbreak.index) + .add(worldInfoBeforeMsg, promptIndex('worldInfoBefore')) + .add(worldInfoAfterMsg, promptIndex('worldInfoAfter')) + .add(charDescriptionMsg, promptIndex('charDescription')) + .add(charPersonalityMsg, promptIndex('charPersonality')) + .add(scenarioMsg, promptIndex('scenario')); + + + // Chat History + const startNewChatPrompt = selected_group ? '[Start a new group chat. Group members: ${names}]' : '[Start a new Chat]'; + chatCompletion.add(new Message('system', startNewChatPrompt, 'newMainChat' ), promptIndex('newMainChat')); + + const chatHistoryIndex = promptIndex('chatHistory'); + [...openai_msgs].reverse().forEach((prompt, index) => { + const message = new Message(prompt.role, prompt.content, 'chatHistory-' + index); + if (chatCompletion.canAfford(message)) chatCompletion.insert(message, chatHistoryIndex); + }); + + const chat = chatCompletion.getChat(); + openai_messages_count = chat.filter(x => x.role === "user" || x.role === "assistant").length; + + return [chat, false]; + + /** + chatCompletion.add(new Message('system', formatWorldInfo(worldInfoBefore), 'worldInfoBefore')); + + console.log(chatCompletion.message); + return; const newChatMessage = chatCompletion.makeSystemMessage('[Start new chat]'); const chatMessages = openai_msgs; @@ -449,10 +491,14 @@ async function prepareOpenAIMessages({ name2, charDescription, charPersonality, {...tokenHandler.getCounts(), ...chatCompletion.getTokenCounts()} ); + console.log(chatCompletion.map) + const openai_msgs_tosend = chatCompletion.getChat(); + console.log(openai_msgs_tosend) openai_messages_count = openai_msgs_tosend.filter(x => x.role === "user" || x.role === "assistant").length; return [openai_msgs_tosend, false]; + **/ } function getGroupMembers(activeGroup) { @@ -921,13 +967,16 @@ class TokenHandler { } count(messages, full, type) { - //console.log(messages); const token_count = this.countTokenFn(messages, full); this.counts[type] += token_count; return token_count; } + getTokensForIdentifier(identifier) { + return this.counts[identifier] ?? 0; + } + getTotal() { return Object.values(this.counts).reduce((a, b) => a + b); } @@ -937,8 +986,6 @@ class TokenHandler { } } -const tokenHandler = new TokenHandler(countTokens); - function countTokens(messages, full = false) { let chatId = 'undefined'; @@ -993,6 +1040,109 @@ function countTokens(messages, full = false) { return token_count; } +const tokenHandler = new TokenHandler(countTokens); + +// Thrown by ChatCompletion when a requested prompt couldn't be found. +class IdentifierNotFoundError extends Error { + constructor(identifier) { + super(`Identifier ${identifier} not found.`); + this.name = 'IdentifierNotFoundError'; + } +} + +class TokenBudgetExceededError extends Error { + constructor(identifier = '') { + super(`Token budged exceeded. Message: ${identifier}`); + this.name = 'TokenBudgetExceeded'; + } +} + +class Message { + identifier; role; content; tokens; + constructor(role, content, identifier = null) { + this.identifier = identifier; + this.role = role; + this.content = content; + this.tokens = tokenHandler.count(this); + } + + getTokens() {return this.tokens}; +} + +class MessageCollection extends Array { + identifier; + constructor(identifier, ...items) { + for(let item of items) { + if(!(item instanceof Message || item instanceof MessageCollection)) { + throw new Error('Only Message and MessageCollection instances can be added to MessageCollection'); + } + } + super(...items); + this.identifier = identifier; + } + + getTokens() { + return this.reduce((tokens, message) => tokens + message.getTokens(), 0); + } +} + +/** + * OpenAI API chat completion representation + * const map = [{identifier: 'example', message: {role: 'system', content: 'exampleContent'}}, ...]; + * + * @see https://platform.openai.com/docs/guides/gpt/chat-completions-api + */ +const ChatCompletion = { + new() { + return { + tokenBudget: 4095, + messages: new MessageCollection(), + add(message, position = null) { + if (!(message instanceof Message)) throw Error('Invalid argument type'); + if (!message.content) return this; + if (false === this.canAfford(message)) throw new TokenBudgetExceededError(message.identifier); + + if (message instanceof MessageCollection) message.forEach(item => this.add(item, position)); + + if (position) this.messages[position] = message; + else this.messages.push(message); + + this.tokenBudget -= message.getTokens(); + + this.log(`Added ${message.identifier}. Remaining tokens: ${this.tokenBudget}`); + + return this; + }, + insert(message, position) { + if (!(message instanceof Message)) throw Error('Invalid argument type'); + if (!message.content) return this; + if (false === this.canAfford(message)) throw new TokenBudgetExceededError(message.identifier); + + this.messages.splice(position, 0, message) + + this.tokenBudget -= message.getTokens(); + + this.log(`Added ${message.identifier}. Remaining tokens: ${this.tokenBudget}`); + }, + canAfford(message) { + return 0 < this.tokenBudget - message.getTokens(); + }, + log(output) { + if (power_user.console_log_prompts) console.log('[ChatCompletion] ' + output); + }, + getTotalTokenCount() { + return this.messages.getTokens(); + }, + getChat() { + return this.messages.reduce((chat, message) => { + if (message.content) chat.push({role: message.role, content: message.content}); + return chat; + }, []); + }, + } + } +}; + export function getTokenizerModel() { // OpenAI models always provide their own tokenizer if (oai_settings.chat_completion_source == chat_completion_sources.OPENAI) {