Refactor addMessageToChatCompletion

This commit is contained in:
maver 2023-06-16 18:18:00 +02:00
parent 5270d261aa
commit 73e3001493

View File

@ -409,23 +409,25 @@ async function prepareOpenAIMessages({
if (power_user.console_log_prompts) chatCompletion.enableLogging(); if (power_user.console_log_prompts) chatCompletion.enableLogging();
// Helper functions // Helper functions
const createMessageCollection = (role, content, identifier) => MessageCollection.fromPrompt(new Prompt({role, content, identifier})); const addToChatCompletion = (role, content, identifier = null) => {
const addMessageToChatCompletion = (role, content, identifier) => { const collection = new MessageCollection(identifier)
chatCompletion.add(createMessageCollection(role, content, identifier), prompts.index(identifier)); if (role && content) collection.addItem(new Message(role, content, identifier));
const index = identifier ? prompts.index(identifier) : null;
chatCompletion.add(collection, index);
}; };
addMessageToChatCompletion('system', formatWorldInfo(worldInfoBefore), 'worldInfoBefore'); addToChatCompletion('system', formatWorldInfo(worldInfoBefore), 'worldInfoBefore');
addMessageToChatCompletion('system', formatWorldInfo(worldInfoAfter), 'worldInfoAfter'); addToChatCompletion('system', formatWorldInfo(worldInfoAfter), 'worldInfoAfter');
addMessageToChatCompletion('system', substituteParams(charDescription), 'charDescription'); addToChatCompletion('system', substituteParams(charDescription), 'charDescription');
addMessageToChatCompletion('system', `${name2}'s personality: ${substituteParams(charPersonality)}`, 'charPersonality'); addToChatCompletion('system', `${name2}'s personality: ${substituteParams(charPersonality)}`, 'charPersonality');
addMessageToChatCompletion('system', `Circumstances and context of the dialogue: ${substituteParams(Scenario)}`, 'scenario'); addToChatCompletion('system', `Circumstances and context of the dialogue: ${substituteParams(Scenario)}`, 'scenario');
// Add main prompt // Add main prompt
if (type === "impersonate") { if (type === "impersonate") {
const impersonate = substituteParams(oai_settings.impersonation_prompt); const impersonate = substituteParams(oai_settings.impersonation_prompt);
addMessageToChatCompletion('system', impersonate, 'main'); addToChatCompletion('system', impersonate, 'main');
} else { } else {
addMessageToChatCompletion('system', prompts.get('main').content, 'main'); addToChatCompletion('system', prompts.get('main').content, 'main');
} }
// Add managed system and user prompts // Add managed system and user prompts
@ -439,13 +441,15 @@ async function prepareOpenAIMessages({
[...systemPrompts, ...userPrompts].forEach(identifier => { [...systemPrompts, ...userPrompts].forEach(identifier => {
if (prompts.has(identifier)) { if (prompts.has(identifier)) {
chatCompletion.add(MessageCollection.fromPrompt(prompts.get(identifier)), prompts.index(identifier)); const prompt = prompts.get(identifier);
addToChatCompletion(prompt.role, prompt.content, identifier);
} }
}); });
// Add enhance definition instruction // Add enhance definition instruction
if (prompts.has('enhanceDefinitions')) { if (prompts.has('enhanceDefinitions')) {
chatCompletion.add(MessageCollection.fromPrompt(prompts.get('enhanceDefinitions')), prompts.index('enhanceDefinitions')); const prompt = prompts.get('enhanceDefinitions');
addToChatCompletion(prompt.role, prompt.content, identifier);
} }
// Insert nsfw avoidance prompt into main, if no nsfw prompt is present // Insert nsfw avoidance prompt into main, if no nsfw prompt is present
@ -457,11 +461,11 @@ async function prepareOpenAIMessages({
// Insert quiet prompt into main // Insert quiet prompt into main
if (quietPrompt) { if (quietPrompt) {
const quietPromptMessage = new Message('system', quietPrompt, 'quietPrompt'); const quietPromptMessage = new Message('system', quietPrompt, 'quietPrompt');
chatCompletion.insert(quietPromptMessage, 'main') chatCompletion.insert(quietPromptMessage, 'main');
} }
if (bias && bias.trim().length) { if (bias && bias.trim().length) {
addMessageToChatCompletion('system', bias, 'main'); addToChatCompletion('system', bias, 'main');
} }
// Add extension prompts // Add extension prompts
@ -481,7 +485,7 @@ async function prepareOpenAIMessages({
} }
// Chat History // Chat History
chatCompletion.add(new MessageCollection('chatHistory'), prompts.index('chatHistory')); addToChatCompletion(null, null, 'chatHistory');
const mainChat = selected_group ? '[Start a new group chat. Group members: ${names}]' : '[Start a new Chat]'; const mainChat = selected_group ? '[Start a new group chat. Group members: ${names}]' : '[Start a new Chat]';
const mainChatMessage = new Message('system', mainChat, 'newMainChat'); const mainChatMessage = new Message('system', mainChat, 'newMainChat');
@ -497,7 +501,7 @@ async function prepareOpenAIMessages({
} }
// Insert chat message examples if there's enough budget // Insert chat message examples if there's enough budget
chatCompletion.add(new MessageCollection('dialogueExamples'), prompts.index('dialogueExamples')); addToChatCompletion(null, null, 'dialogueExamples');
if (chatCompletion.canAfford(mainChatMessage)) { if (chatCompletion.canAfford(mainChatMessage)) {
// Insert dialogue examples messages // Insert dialogue examples messages
chatCompletion.insert(mainChatMessage, 'dialogueExamples'); chatCompletion.insert(mainChatMessage, 'dialogueExamples');
@ -1081,10 +1085,6 @@ class Message {
} }
getTokens() {return this.tokens}; getTokens() {return this.tokens};
static fromPrompt(prompt) {
return new Message(prompt.role, prompt.content, prompt.identifier);
}
} }
class MessageCollection { class MessageCollection {
@ -1101,12 +1101,12 @@ class MessageCollection {
this.identifier = identifier; this.identifier = identifier;
} }
getTokens() { addItem(item) {
return this.collection.reduce((tokens, message) => tokens + message.getTokens(), 0); this.collection.push(item);
} }
static fromPrompt(prompt) { getTokens() {
return new MessageCollection(prompt.identifier, Message.fromPrompt(prompt)); return this.collection.reduce((tokens, message) => tokens + message.getTokens(), 0);
} }
} }