Introduction of additional helper classes, refactoring

This commit is contained in:
maver
2023-06-14 22:36:14 +02:00
parent 18641ea3d2
commit 8ae2c80358
2 changed files with 225 additions and 92 deletions

View File

@ -367,50 +367,70 @@ function formatWorldInfo(value) {
return stringFormat(oai_settings.wi_format, value);
}
async function prepareOpenAIMessages({ name2, charDescription, charPersonality, Scenario, worldInfoBefore, worldInfoAfter, bias, type, quietPrompt, extensionPrompts, cyclePrompt } = {}) {
const promptList = promptManager.getOrderedPromptList();
const promptIndex = (identifier) => promptList.findIndex(prompt => prompt.identifier === identifier);
const getMessageWithIndex = (identifier) => {
const index = promptIndex(identifier);
const prompt = promptList[index];
const msg = new Message(prompt.role, prompt.content, prompt.identifier);
return {message: msg, index: index}
}
async function prepareOpenAIMessages({
name2,
charDescription,
charPersonality,
Scenario,
worldInfoBefore,
worldInfoAfter,
bias,
type,
quietPrompt,
extensionPrompts,
cyclePrompt
} = {}) {
const prompts = promptManager.getPromptCollection();
const chatCompletion = new ChatCompletion();
const chatCompletion = ChatCompletion.new();
chatCompletion.tokenBudget = promptManager.serviceSettings.openai_max_context - promptManager.serviceSettings.amount_gen;
const main = getMessageWithIndex('main');
const nsfw = getMessageWithIndex('nsfw');
const jailbreak = getMessageWithIndex('jailbreak');
if (power_user.console_log_prompts) chatCompletion.enableLogging();
const worldInfoBeforeMsg = new Message('system', formatWorldInfo(worldInfoBefore), 'worldInfoBefore');
const worldInfoAfterMsg = new Message('system', formatWorldInfo(worldInfoAfter), 'worldInfoAfter');
const charDescriptionMsg = new Message('system', substituteParams(charDescription), 'charDescription');
const charPersonalityMsg = new Message('system', `${name2}'s personality: ${substituteParams(charPersonality)}`, 'charPersonality');
const scenarioMsg = new Message('system', `Circumstances and context of the dialogue: ${substituteParams(Scenario)}`, 'scenario');
// Helper functions
const createMessageCollection = (role, content, identifier) => MessageCollection.fromPrompt(new Prompt({role, content, identifier}));
const addMessageToChatCompletion = (role, content, identifier) => {
chatCompletion.add(createMessageCollection(role, content, identifier), prompts.index(identifier));
};
chatCompletion
.add(main.message, main.index)
.add(nsfw.message, nsfw.index)
.add(jailbreak.message, jailbreak.index)
.add(worldInfoBeforeMsg, promptIndex('worldInfoBefore'))
.add(worldInfoAfterMsg, promptIndex('worldInfoAfter'))
.add(charDescriptionMsg, promptIndex('charDescription'))
.add(charPersonalityMsg, promptIndex('charPersonality'))
.add(scenarioMsg, promptIndex('scenario'));
addMessageToChatCompletion('system', formatWorldInfo(worldInfoBefore), 'worldInfoBefore');
addMessageToChatCompletion('system', formatWorldInfo(worldInfoAfter), 'worldInfoAfter');
addMessageToChatCompletion('system', substituteParams(charDescription), 'charDescription');
addMessageToChatCompletion('system', `${name2}'s personality: ${substituteParams(charPersonality)}`, 'charPersonality');
addMessageToChatCompletion('system', `Circumstances and context of the dialogue: ${substituteParams(Scenario)}`, 'scenario');
const optionalSystemPrompts = ['nsfw', 'jailbreak'];
const userPrompts = prompts.collection
.filter((prompt) => false === prompt.system_prompt)
.reduce((acc, prompt) => {
acc.push(prompt.identifier)
return acc;
}, []);
// Add optional prompts if they exist
[...optionalSystemPrompts, ...userPrompts].forEach(identifier => {
if (prompts.has(identifier)) {
chatCompletion.add(MessageCollection.fromPrompt(prompts.get(identifier)), prompts.index(identifier));
}
});
// Chat History
const startNewChatPrompt = selected_group ? '[Start a new group chat. Group members: ${names}]' : '[Start a new Chat]';
chatCompletion.add(new Message('system', startNewChatPrompt, 'newMainChat' ), promptIndex('newMainChat'));
const startNewChat = selected_group ? '[Start a new group chat. Group members: ${names}]' : '[Start a new Chat]';
chatCompletion.add(new MessageCollection('chatHistory'), prompts.index('chatHistory'));
chatCompletion.insert(new Message('system', startNewChat, 'newMainChat'), 'chatHistory');
const chatHistoryIndex = promptIndex('chatHistory');
// Insert chat messages
[...openai_msgs].reverse().forEach((prompt, index) => {
const message = new Message(prompt.role, prompt.content, 'chatHistory-' + index);
if (chatCompletion.canAfford(message)) chatCompletion.insert(message, chatHistoryIndex);
const chatMessage = new Message(prompt.role, prompt.content, 'chatHistory-' + index);
if (chatCompletion.canAfford(chatMessage)) {
chatCompletion.insert(chatMessage, 'chatHistory');
}
});
// Insert chat message examples if there's enough budget
//ToDo
const chat = chatCompletion.getChat();
openai_messages_count = chat.filter(x => x.role === "user" || x.role === "assistant").length;
@ -1058,8 +1078,8 @@ class TokenBudgetExceededError extends Error {
}
class Message {
identifier; role; content; tokens;
constructor(role, content, identifier = null) {
tokens; identifier; role; content;
constructor(role, content, identifier) {
this.identifier = identifier;
this.role = role;
this.content = content;
@ -1067,9 +1087,14 @@ class Message {
}
getTokens() {return this.tokens};
static fromPrompt(prompt) {
return new Message(prompt.role, prompt.content, prompt.identifier);
}
}
class MessageCollection extends Array {
class MessageCollection {
collection = [];
identifier;
constructor(identifier, ...items) {
for(let item of items) {
@ -1077,12 +1102,17 @@ class MessageCollection extends Array {
throw new Error('Only Message and MessageCollection instances can be added to MessageCollection');
}
}
super(...items);
this.collection.push(...items);
this.identifier = identifier;
}
getTokens() {
return this.reduce((tokens, message) => tokens + message.getTokens(), 0);
return this.collection.reduce((tokens, message) => tokens + message.getTokens(), 0);
}
static fromPrompt(prompt) {
return new MessageCollection(prompt.identifier, Message.fromPrompt(prompt));
}
}
@ -1092,56 +1122,110 @@ class MessageCollection extends Array {
*
* @see https://platform.openai.com/docs/guides/gpt/chat-completions-api
*/
const ChatCompletion = {
new() {
return {
tokenBudget: 4095,
messages: new MessageCollection(),
add(message, position = null) {
if (!(message instanceof Message)) throw Error('Invalid argument type');
if (!message.content) return this;
if (false === this.canAfford(message)) throw new TokenBudgetExceededError(message.identifier);
class ChatCompletion {
constructor() {
this.tokenBudget = 0;
this.messages = new MessageCollection();
this.loggingEnabled = false;
}
if (message instanceof MessageCollection) message.forEach(item => this.add(item, position));
add(collection, position = null) {
this.validateMessageCollection(collection);
this.checkTokenBudget(collection, collection.identifier);
if (position) this.messages[position] = message;
else this.messages.push(message);
if (position) {
this.messages.collection[position] = collection;
} else {
this.messages.collection.push(collection);
}
this.tokenBudget -= message.getTokens();
this.decreaseTokenBudgetBy(collection.getTokens());
this.log(`Added ${collection.identifier}. Remaining tokens: ${this.tokenBudget}`);
this.log(`Added ${message.identifier}. Remaining tokens: ${this.tokenBudget}`);
return this;
}
return this;
},
insert(message, position) {
if (!(message instanceof Message)) throw Error('Invalid argument type');
if (!message.content) return this;
if (false === this.canAfford(message)) throw new TokenBudgetExceededError(message.identifier);
insert(message, identifier) {
this.validateMessage(message);
this.checkTokenBudget(message, message.identifier);
this.messages.splice(position, 0, message)
this.tokenBudget -= message.getTokens();
this.log(`Added ${message.identifier}. Remaining tokens: ${this.tokenBudget}`);
},
canAfford(message) {
return 0 < this.tokenBudget - message.getTokens();
},
log(output) {
if (power_user.console_log_prompts) console.log('[ChatCompletion] ' + output);
},
getTotalTokenCount() {
return this.messages.getTokens();
},
getChat() {
return this.messages.reduce((chat, message) => {
if (message.content) chat.push({role: message.role, content: message.content});
return chat;
}, []);
},
const index = this.findMessageIndex(identifier);
if (message.content) {
this.messages.collection[index].collection.push(message);
this.decreaseTokenBudgetBy(message.getTokens());
this.log(`Added ${message.identifier}. Remaining tokens: ${this.tokenBudget}`);
}
}
};
canAfford(message) {
return 0 < this.tokenBudget - message.getTokens();
}
has(identifier) {
return this.messages.collection.some(message => message.identifier === identifier);
}
getTotalTokenCount() {
return this.messages.getTokens();
}
getChat() {
const chat = [];
for (let item of this.messages.collection) {
if (item instanceof MessageCollection) {
const messages = item.collection.reduce((acc, message) => {
if (message.content) acc.push({role: message.role, content: message.content});
return acc;
}, []);
chat.push(...messages);
}
}
return chat;
}
log(output) {
if (this.loggingEnabled) console.log('[ChatCompletion] ' + output);
}
enableLogging() {
this.loggingEnabled = true;
}
disableLogging() {
this.loggingEnabled = false;
}
// Move validation to its own method for readability
validateMessageCollection(collection) {
if (!(collection instanceof MessageCollection)) {
throw new Error('Argument must be an instance of MessageCollection');
}
}
validateMessage(message) {
if (!(message instanceof Message)) {
throw new Error('Argument must be an instance of Message');
}
}
checkTokenBudget(message, identifier) {
if (!this.canAfford(message)) {
throw new TokenBudgetExceededError(identifier);
}
}
decreaseTokenBudgetBy(tokens) {
this.tokenBudget -= tokens;
}
findMessageIndex(identifier) {
const index = this.messages.collection.findIndex(item => item?.identifier === identifier);
if (index < 0) {
throw new IdentifierNotFoundError(identifier);
}
return index;
}
}
export function getTokenizerModel() {
// OpenAI models always provide their own tokenizer