Add prototype for prompt manager token management

This commit is contained in:
maver
2023-06-13 20:48:06 +02:00
parent 05f7e5677d
commit 18641ea3d2
2 changed files with 171 additions and 120 deletions

View File

@ -27,7 +27,7 @@ import {
import {groups, selected_group} from "./group-chats.js";
import {
defaultPromptManagerSettings, IdentifierNotFoundError,
defaultPromptManagerSettings,
openAiDefaultPromptLists,
openAiDefaultPrompts,
PromptManagerModule as PromptManager
@ -66,7 +66,10 @@ export {
setOpenAIOnlineStatus,
getChatCompletionModel,
countTokens,
TokenHandler
TokenHandler,
IdentifierNotFoundError,
Message,
MessageCollection
}
let openai_msgs = [];
@ -365,20 +368,59 @@ function formatWorldInfo(value) {
}
async function prepareOpenAIMessages({ name2, charDescription, charPersonality, Scenario, worldInfoBefore, worldInfoAfter, bias, type, quietPrompt, extensionPrompts, cyclePrompt } = {}) {
const chatCompletion = promptManager.getChatCompletion();
const promptList = promptManager.getOrderedPromptList();
const promptIndex = (identifier) => promptList.findIndex(prompt => prompt.identifier === identifier);
const getMessageWithIndex = (identifier) => {
const index = promptIndex(identifier);
const prompt = promptList[index];
const msg = new Message(prompt.role, prompt.content, prompt.identifier);
return {message: msg, index: index}
}
// Prepare messages
const worldInfoBeforeMessage = chatCompletion.makeSystemMessage(formatWorldInfo(worldInfoBefore));
const worldInfoAfterMessage = chatCompletion.makeSystemMessage(formatWorldInfo(worldInfoAfter));
const charDescriptionMessage = chatCompletion.makeSystemMessage(substituteParams(charDescription));
const chatCompletion = ChatCompletion.new();
chatCompletion.tokenBudget = promptManager.serviceSettings.openai_max_context - promptManager.serviceSettings.amount_gen;
const charPersonalityMessage = chatCompletion.makeSystemMessage(
name2 + 's personality: ' + substituteParams(charPersonality)
);
const main = getMessageWithIndex('main');
const nsfw = getMessageWithIndex('nsfw');
const jailbreak = getMessageWithIndex('jailbreak');
const scenarioMessage = chatCompletion.makeSystemMessage(
'Circumstances and context of the dialogue: ' + substituteParams(Scenario)
);
const worldInfoBeforeMsg = new Message('system', formatWorldInfo(worldInfoBefore), 'worldInfoBefore');
const worldInfoAfterMsg = new Message('system', formatWorldInfo(worldInfoAfter), 'worldInfoAfter');
const charDescriptionMsg = new Message('system', substituteParams(charDescription), 'charDescription');
const charPersonalityMsg = new Message('system', `${name2}'s personality: ${substituteParams(charPersonality)}`, 'charPersonality');
const scenarioMsg = new Message('system', `Circumstances and context of the dialogue: ${substituteParams(Scenario)}`, 'scenario');
chatCompletion
.add(main.message, main.index)
.add(nsfw.message, nsfw.index)
.add(jailbreak.message, jailbreak.index)
.add(worldInfoBeforeMsg, promptIndex('worldInfoBefore'))
.add(worldInfoAfterMsg, promptIndex('worldInfoAfter'))
.add(charDescriptionMsg, promptIndex('charDescription'))
.add(charPersonalityMsg, promptIndex('charPersonality'))
.add(scenarioMsg, promptIndex('scenario'));
// Chat History
const startNewChatPrompt = selected_group ? '[Start a new group chat. Group members: ${names}]' : '[Start a new Chat]';
chatCompletion.add(new Message('system', startNewChatPrompt, 'newMainChat' ), promptIndex('newMainChat'));
const chatHistoryIndex = promptIndex('chatHistory');
[...openai_msgs].reverse().forEach((prompt, index) => {
const message = new Message(prompt.role, prompt.content, 'chatHistory-' + index);
if (chatCompletion.canAfford(message)) chatCompletion.insert(message, chatHistoryIndex);
});
const chat = chatCompletion.getChat();
openai_messages_count = chat.filter(x => x.role === "user" || x.role === "assistant").length;
return [chat, false];
/**
chatCompletion.add(new Message('system', formatWorldInfo(worldInfoBefore), 'worldInfoBefore'));
console.log(chatCompletion.message);
return;
const newChatMessage = chatCompletion.makeSystemMessage('[Start new chat]');
const chatMessages = openai_msgs;
@ -449,10 +491,14 @@ async function prepareOpenAIMessages({ name2, charDescription, charPersonality,
{...tokenHandler.getCounts(), ...chatCompletion.getTokenCounts()}
);
console.log(chatCompletion.map)
const openai_msgs_tosend = chatCompletion.getChat();
console.log(openai_msgs_tosend)
openai_messages_count = openai_msgs_tosend.filter(x => x.role === "user" || x.role === "assistant").length;
return [openai_msgs_tosend, false];
**/
}
function getGroupMembers(activeGroup) {
@ -921,13 +967,16 @@ class TokenHandler {
}
count(messages, full, type) {
//console.log(messages);
const token_count = this.countTokenFn(messages, full);
this.counts[type] += token_count;
return token_count;
}
getTokensForIdentifier(identifier) {
return this.counts[identifier] ?? 0;
}
getTotal() {
return Object.values(this.counts).reduce((a, b) => a + b);
}
@ -937,8 +986,6 @@ class TokenHandler {
}
}
const tokenHandler = new TokenHandler(countTokens);
function countTokens(messages, full = false) {
let chatId = 'undefined';
@ -993,6 +1040,109 @@ function countTokens(messages, full = false) {
return token_count;
}
const tokenHandler = new TokenHandler(countTokens);
// Thrown by ChatCompletion when a requested prompt couldn't be found.
class IdentifierNotFoundError extends Error {
constructor(identifier) {
super(`Identifier ${identifier} not found.`);
this.name = 'IdentifierNotFoundError';
}
}
class TokenBudgetExceededError extends Error {
constructor(identifier = '') {
super(`Token budged exceeded. Message: ${identifier}`);
this.name = 'TokenBudgetExceeded';
}
}
class Message {
identifier; role; content; tokens;
constructor(role, content, identifier = null) {
this.identifier = identifier;
this.role = role;
this.content = content;
this.tokens = tokenHandler.count(this);
}
getTokens() {return this.tokens};
}
class MessageCollection extends Array {
identifier;
constructor(identifier, ...items) {
for(let item of items) {
if(!(item instanceof Message || item instanceof MessageCollection)) {
throw new Error('Only Message and MessageCollection instances can be added to MessageCollection');
}
}
super(...items);
this.identifier = identifier;
}
getTokens() {
return this.reduce((tokens, message) => tokens + message.getTokens(), 0);
}
}
/**
* OpenAI API chat completion representation
* const map = [{identifier: 'example', message: {role: 'system', content: 'exampleContent'}}, ...];
*
* @see https://platform.openai.com/docs/guides/gpt/chat-completions-api
*/
const ChatCompletion = {
new() {
return {
tokenBudget: 4095,
messages: new MessageCollection(),
add(message, position = null) {
if (!(message instanceof Message)) throw Error('Invalid argument type');
if (!message.content) return this;
if (false === this.canAfford(message)) throw new TokenBudgetExceededError(message.identifier);
if (message instanceof MessageCollection) message.forEach(item => this.add(item, position));
if (position) this.messages[position] = message;
else this.messages.push(message);
this.tokenBudget -= message.getTokens();
this.log(`Added ${message.identifier}. Remaining tokens: ${this.tokenBudget}`);
return this;
},
insert(message, position) {
if (!(message instanceof Message)) throw Error('Invalid argument type');
if (!message.content) return this;
if (false === this.canAfford(message)) throw new TokenBudgetExceededError(message.identifier);
this.messages.splice(position, 0, message)
this.tokenBudget -= message.getTokens();
this.log(`Added ${message.identifier}. Remaining tokens: ${this.tokenBudget}`);
},
canAfford(message) {
return 0 < this.tokenBudget - message.getTokens();
},
log(output) {
if (power_user.console_log_prompts) console.log('[ChatCompletion] ' + output);
},
getTotalTokenCount() {
return this.messages.getTokens();
},
getChat() {
return this.messages.reduce((chat, message) => {
if (message.content) chat.push({role: message.role, content: message.content});
return chat;
}, []);
},
}
}
};
export function getTokenizerModel() {
// OpenAI models always provide their own tokenizer
if (oai_settings.chat_completion_source == chat_completion_sources.OPENAI) {