Chat Completion: switch to async token handling

This commit is contained in:
Cohee 2024-10-12 01:07:36 +03:00
parent b0d0f2111b
commit 8e082e622b
2 changed files with 104 additions and 62 deletions

View File

@ -315,7 +315,7 @@ class PromptManager {
*/
init(moduleConfiguration, serviceSettings) {
this.configuration = Object.assign(this.configuration, moduleConfiguration);
this.tokenHandler = this.tokenHandler || new TokenHandler();
this.tokenHandler = this.tokenHandler || new TokenHandler(() => { throw new Error('Token handler not set'); });
this.serviceSettings = serviceSettings;
this.containerElement = document.getElementById(this.configuration.containerIdentifier);

View File

@ -60,7 +60,7 @@ import {
resetScrollHeight,
stringFormat,
} from './utils.js';
import { countTokensOpenAI, getTokenizerModel } from './tokenizers.js';
import { countTokensOpenAIAsync, getTokenizerModel } from './tokenizers.js';
import { isMobile } from './RossAscends-mods.js';
import { saveLogprobsForActiveMessage } from './logprobs.js';
import { SlashCommandParser } from './slash-commands/SlashCommandParser.js';
@ -671,14 +671,14 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
// Reserve budget for new chat message
const newChat = selected_group ? oai_settings.new_group_chat_prompt : oai_settings.new_chat_prompt;
const newChatMessage = new Message('system', substituteParams(newChat), 'newMainChat');
const newChatMessage = await Message.createAsync('system', substituteParams(newChat), 'newMainChat');
chatCompletion.reserveBudget(newChatMessage);
// Reserve budget for group nudge
let groupNudgeMessage = null;
const noGroupNudgeTypes = ['impersonate'];
if (selected_group && prompts.has('groupNudge') && !noGroupNudgeTypes.includes(type)) {
groupNudgeMessage = Message.fromPrompt(prompts.get('groupNudge'));
groupNudgeMessage = await Message.fromPromptAsync(prompts.get('groupNudge'));
chatCompletion.reserveBudget(groupNudgeMessage);
}
@ -693,12 +693,12 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
};
const continuePrompt = new Prompt(promptObject);
const preparedPrompt = promptManager.preparePrompt(continuePrompt);
continueMessage = Message.fromPrompt(preparedPrompt);
continueMessage = await Message.fromPromptAsync(preparedPrompt);
chatCompletion.reserveBudget(continueMessage);
}
const lastChatPrompt = messages[messages.length - 1];
const message = new Message('user', oai_settings.send_if_empty, 'emptyUserMessageReplacement');
const message = await Message.createAsync('user', oai_settings.send_if_empty, 'emptyUserMessageReplacement');
if (lastChatPrompt && lastChatPrompt.role === 'assistant' && oai_settings.send_if_empty && chatCompletion.canAfford(message)) {
chatCompletion.insert(message, 'chatHistory');
}
@ -715,11 +715,11 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
// We do not want to mutate the prompt
const prompt = new Prompt(chatPrompt);
prompt.identifier = `chatHistory-${messages.length - index}`;
const chatMessage = Message.fromPrompt(promptManager.preparePrompt(prompt));
const chatMessage = await Message.fromPromptAsync(promptManager.preparePrompt(prompt));
if (promptManager.serviceSettings.names_behavior === character_names_behavior.COMPLETION && prompt.name) {
const messageName = promptManager.isValidName(prompt.name) ? prompt.name : promptManager.sanitizeName(prompt.name);
chatMessage.setName(messageName);
await chatMessage.setName(messageName);
}
if (imageInlining && chatPrompt.image) {
@ -729,9 +729,9 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
if (canUseTools && Array.isArray(chatPrompt.invocations)) {
/** @type {import('./tool-calling.js').ToolInvocation[]} */
const invocations = chatPrompt.invocations;
const toolCallMessage = new Message(chatMessage.role, undefined, 'toolCall-' + chatMessage.identifier);
const toolResultMessages = invocations.slice().reverse().map((invocation) => new Message('tool', invocation.result || '[No content]', invocation.id));
toolCallMessage.setToolCalls(invocations);
const toolCallMessage = await Message.createAsync(chatMessage.role, undefined, 'toolCall-' + chatMessage.identifier);
const toolResultMessages = await Promise.all(invocations.slice().reverse().map((invocation) => Message.createAsync('tool', invocation.result || '[No content]', invocation.id)));
await toolCallMessage.setToolCalls(invocations);
if (chatCompletion.canAffordAll([toolCallMessage, ...toolResultMessages])) {
for (const resultMessage of toolResultMessages) {
chatCompletion.insertAtStart(resultMessage, 'chatHistory');
@ -748,7 +748,8 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
if (type === 'continue' && oai_settings.continue_prefill && chatPrompt === firstNonInjected) {
// in case we are using continue_prefill and the latest message is an assistant message, we want to prepend the users assistant prefill on the message
if (chatPrompt.role === 'assistant') {
const collection = new MessageCollection('continuePrefill', new Message(chatMessage.role, substituteParams(oai_settings.assistant_prefill + '\n\n') + chatMessage.content, chatMessage.identifier));
const continueMessage = await Message.createAsync(chatMessage.role, substituteParams(oai_settings.assistant_prefill + '\n\n') + chatMessage.content, chatMessage.identifier);
const collection = new MessageCollection('continuePrefill', );
chatCompletion.add(collection, -1);
continue;
}
@ -787,15 +788,16 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
* @param {ChatCompletion} chatCompletion - An instance of ChatCompletion class that will be populated with the prompts.
* @param {Object[]} messageExamples - Array containing all message examples.
*/
function populateDialogueExamples(prompts, chatCompletion, messageExamples) {
async function populateDialogueExamples(prompts, chatCompletion, messageExamples) {
if (!prompts.has('dialogueExamples')) {
return;
}
chatCompletion.add(new MessageCollection('dialogueExamples'), prompts.index('dialogueExamples'));
if (Array.isArray(messageExamples) && messageExamples.length) {
const newExampleChat = new Message('system', substituteParams(oai_settings.new_example_chat_prompt), 'newChat');
[...messageExamples].forEach((dialogue, dialogueIndex) => {
const newExampleChat = await Message.createAsync('system', substituteParams(oai_settings.new_example_chat_prompt), 'newChat');
for (const dialogue of [...messageExamples]) {
const dialogueIndex = messageExamples.indexOf(dialogue);
let examplesAdded = 0;
if (chatCompletion.canAfford(newExampleChat)) chatCompletion.insert(newExampleChat, 'dialogueExamples');
@ -806,8 +808,8 @@ function populateDialogueExamples(prompts, chatCompletion, messageExamples) {
const content = prompt.content || '';
const identifier = `dialogueExamples ${dialogueIndex}-${promptIndex}`;
const chatMessage = new Message(role, content, identifier);
chatMessage.setName(prompt.name);
const chatMessage = await Message.createAsync(role, content, identifier);
await chatMessage.setName(prompt.name);
if (!chatCompletion.canAfford(chatMessage)) {
break;
}
@ -818,7 +820,7 @@ function populateDialogueExamples(prompts, chatCompletion, messageExamples) {
if (0 === examplesAdded) {
chatCompletion.removeLastFrom('dialogueExamples');
}
});
}
}
}
@ -873,7 +875,7 @@ function getPromptRole(role) {
*/
async function populateChatCompletion(prompts, chatCompletion, { bias, quietPrompt, quietImage, type, cyclePrompt, messages, messageExamples }) {
// Helper function for preparing a prompt, that already exists within the prompt collection, for completion
const addToChatCompletion = (source, target = null) => {
const addToChatCompletion = async (source, target = null) => {
// We need the prompts array to determine a position for the source.
if (false === prompts.has(source)) return;
@ -891,30 +893,31 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
const index = target ? prompts.index(target) : prompts.index(source);
const collection = new MessageCollection(source);
collection.add(Message.fromPrompt(prompt));
const message = await Message.fromPromptAsync(prompt);
collection.add(message);
chatCompletion.add(collection, index);
};
chatCompletion.reserveBudget(3); // every reply is primed with <|start|>assistant<|message|>
// Character and world information
addToChatCompletion('worldInfoBefore');
addToChatCompletion('main');
addToChatCompletion('worldInfoAfter');
addToChatCompletion('charDescription');
addToChatCompletion('charPersonality');
addToChatCompletion('scenario');
addToChatCompletion('personaDescription');
await addToChatCompletion('worldInfoBefore');
await addToChatCompletion('main');
await addToChatCompletion('worldInfoAfter');
await addToChatCompletion('charDescription');
await addToChatCompletion('charPersonality');
await addToChatCompletion('scenario');
await addToChatCompletion('personaDescription');
// Collection of control prompts that will always be positioned last
chatCompletion.setOverriddenPrompts(prompts.overriddenPrompts);
const controlPrompts = new MessageCollection('controlPrompts');
const impersonateMessage = Message.fromPrompt(prompts.get('impersonate')) ?? null;
const impersonateMessage = await Message.fromPromptAsync(prompts.get('impersonate')) ?? null;
if (type === 'impersonate') controlPrompts.add(impersonateMessage);
// Add quiet prompt to control prompts
// This should always be last, even in control prompts. Add all further control prompts BEFORE this prompt
const quietPromptMessage = Message.fromPrompt(prompts.get('quietPrompt')) ?? null;
const quietPromptMessage = await Message.fromPromptAsync(prompts.get('quietPrompt')) ?? null;
if (quietPromptMessage && quietPromptMessage.content) {
if (isImageInliningSupported() && quietImage) {
await quietPromptMessage.addImage(quietImage);
@ -940,20 +943,23 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
return acc;
}, []);
[...systemPrompts, ...userRelativePrompts].forEach(identifier => addToChatCompletion(identifier));
for (const identifier of [...systemPrompts, ...userRelativePrompts]) {
await addToChatCompletion(identifier);
}
// Add enhance definition instruction
if (prompts.has('enhanceDefinitions')) addToChatCompletion('enhanceDefinitions');
if (prompts.has('enhanceDefinitions')) await addToChatCompletion('enhanceDefinitions');
// Bias
if (bias && bias.trim().length) addToChatCompletion('bias');
if (bias && bias.trim().length) await addToChatCompletion('bias');
// Tavern Extras - Summary
if (prompts.has('summary')) {
const summary = prompts.get('summary');
if (summary.position) {
chatCompletion.insert(Message.fromPrompt(summary), 'main', summary.position);
const message = await Message.fromPromptAsync(summary);
chatCompletion.insert(message, 'main', summary.position);
}
}
@ -962,7 +968,8 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
const authorsNote = prompts.get('authorsNote');
if (authorsNote.position) {
chatCompletion.insert(Message.fromPrompt(authorsNote), 'main', authorsNote.position);
const message = await Message.fromPromptAsync(authorsNote);
chatCompletion.insert(message, 'main', authorsNote.position);
}
}
@ -971,7 +978,8 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
const vectorsMemory = prompts.get('vectorsMemory');
if (vectorsMemory.position) {
chatCompletion.insert(Message.fromPrompt(vectorsMemory), 'main', vectorsMemory.position);
const message = await Message.fromPromptAsync(vectorsMemory);
chatCompletion.insert(message, 'main', vectorsMemory.position);
}
}
@ -980,7 +988,8 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
const vectorsDataBank = prompts.get('vectorsDataBank');
if (vectorsDataBank.position) {
chatCompletion.insert(Message.fromPrompt(vectorsDataBank), 'main', vectorsDataBank.position);
const message = await Message.fromPromptAsync(vectorsDataBank);
chatCompletion.insert(message, 'main', vectorsDataBank.position);
}
}
@ -989,13 +998,15 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
const smartContext = prompts.get('smartContext');
if (smartContext.position) {
chatCompletion.insert(Message.fromPrompt(smartContext), 'main', smartContext.position);
const message = await Message.fromPromptAsync(smartContext);
chatCompletion.insert(message, 'main', smartContext.position);
}
}
// Other relative extension prompts
for (const prompt of prompts.collection.filter(p => p.extension && p.position)) {
chatCompletion.insert(Message.fromPrompt(prompt), 'main', prompt.position);
const message = await Message.fromPromptAsync(prompt);
chatCompletion.insert(message, 'main', prompt.position);
}
// Pre-allocation of tokens for tool data
@ -1003,7 +1014,7 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
const toolData = {};
await ToolManager.registerFunctionToolsOpenAI(toolData);
const toolMessage = [{ role: 'user', content: JSON.stringify(toolData) }];
const toolTokens = tokenHandler.count(toolMessage);
const toolTokens = await tokenHandler.countAsync(toolMessage);
chatCompletion.reserveBudget(toolTokens);
}
@ -1012,11 +1023,11 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
// Decide whether dialogue examples should always be added
if (power_user.pin_examples) {
populateDialogueExamples(prompts, chatCompletion, messageExamples);
await populateDialogueExamples(prompts, chatCompletion, messageExamples);
await populateChatHistory(messages, prompts, chatCompletion, type, cyclePrompt);
} else {
await populateChatHistory(messages, prompts, chatCompletion, type, cyclePrompt);
populateDialogueExamples(prompts, chatCompletion, messageExamples);
await populateDialogueExamples(prompts, chatCompletion, messageExamples);
}
chatCompletion.freeBudget(controlPrompts);
@ -1281,7 +1292,7 @@ export async function prepareOpenAIMessages({
promptManager.setChatCompletion(chatCompletion);
if (oai_settings.squash_system_messages && dryRun == false) {
chatCompletion.squashSystemMessages();
await chatCompletion.squashSystemMessages();
}
// All information is up-to-date, render.
@ -2127,8 +2138,11 @@ async function calculateLogitBias() {
}
class TokenHandler {
constructor(countTokenFn) {
this.countTokenFn = countTokenFn;
/**
* @param {(messages: object[] | object, full?: boolean) => Promise<number>} countTokenAsyncFn Function to count tokens
*/
constructor(countTokenAsyncFn) {
this.countTokenAsyncFn = countTokenAsyncFn;
this.counts = {
'start_chat': 0,
'prompt': 0,
@ -2157,8 +2171,15 @@ class TokenHandler {
this.counts[type] -= value;
}
count(messages, full, type) {
const token_count = this.countTokenFn(messages, full);
/**
* Count tokens for a message or messages.
* @param {object|any[]} messages Messages to count tokens for
* @param {boolean} [full] Count full tokens
* @param {string} [type] Identifier for the token count
* @returns {Promise<number>} The token count
*/
async countAsync(messages, full, type) {
const token_count = await this.countTokenAsyncFn(messages, full);
this.counts[type] += token_count;
return token_count;
@ -2178,7 +2199,7 @@ class TokenHandler {
}
const tokenHandler = new TokenHandler(countTokensOpenAI);
const tokenHandler = new TokenHandler(countTokensOpenAIAsync);
// Thrown by ChatCompletion when a requested prompt couldn't be found.
class IdentifierNotFoundError extends Error {
@ -2228,6 +2249,7 @@ class Message {
* @param {string} role - The role of the entity creating the message.
* @param {string} content - The actual content of the message.
* @param {string} identifier - A unique identifier for the message.
* @private Don't use this constructor directly. Use createAsync instead.
*/
constructor(role, content, identifier) {
this.identifier = identifier;
@ -2239,18 +2261,32 @@ class Message {
this.role = 'system';
}
if (typeof this.content === 'string' && this.content.length > 0) {
this.tokens = tokenHandler.count({ role: this.role, content: this.content });
} else {
this.tokens = 0;
this.tokens = 0;
}
/**
* Create a new Message instance.
* @param {string} role
* @param {string} content
* @param {string} identifier
* @returns {Promise<Message>} Message instance
*/
static async createAsync(role, content, identifier) {
const message = new Message(role, content, identifier);
if (typeof message.content === 'string' && message.content.length > 0) {
message.tokens = await tokenHandler.countAsync({ role: message.role, content: message.content });
}
return message;
}
/**
* Reconstruct the message from a tool invocation.
* @param {import('./tool-calling.js').ToolInvocation[]} invocations
* @param {import('./tool-calling.js').ToolInvocation[]} invocations - The tool invocations to reconstruct the message from.
* @returns {Promise<void>}
*/
setToolCalls(invocations) {
async setToolCalls(invocations) {
this.tool_calls = invocations.map(i => ({
id: i.id,
type: 'function',
@ -2259,12 +2295,17 @@ class Message {
name: i.name,
},
}));
this.tokens = tokenHandler.count({ role: this.role, tool_calls: JSON.stringify(this.tool_calls) });
this.tokens = await tokenHandler.countAsync({ role: this.role, tool_calls: JSON.stringify(this.tool_calls) });
}
setName(name) {
/**
* Add a name to the message.
* @param {string} name Name to set for the message.
* @returns {Promise<void>}
*/
async setName(name) {
this.name = name;
this.tokens = tokenHandler.count({ role: this.role, content: this.content, name: this.name });
this.tokens = await tokenHandler.countAsync({ role: this.role, content: this.content, name: this.name });
}
async addImage(image) {
@ -2356,13 +2397,13 @@ class Message {
}
/**
* Create a new Message instance from a prompt.
* Create a new Message instance from a prompt asynchronously.
* @static
* @param {Object} prompt - The prompt object.
* @returns {Message} A new instance of Message.
* @returns {Promise<Message>} A new instance of Message.
*/
static fromPrompt(prompt) {
return new Message(prompt.role, prompt.content, prompt.identifier);
static async fromPromptAsync(prompt) {
return Message.createAsync(prompt.role, prompt.content, prompt.identifier);
}
/**
@ -2488,8 +2529,9 @@ export class ChatCompletion {
/**
* Combines consecutive system messages into one if they have no name attached.
* @returns {Promise<void>}
*/
squashSystemMessages() {
async squashSystemMessages() {
const excludeList = ['newMainChat', 'newChat', 'groupNudge'];
this.messages.collection = this.messages.flatten();
@ -2509,7 +2551,7 @@ export class ChatCompletion {
if (shouldSquash(message)) {
if (lastMessage && shouldSquash(lastMessage)) {
lastMessage.content += '\n' + message.content;
lastMessage.tokens = tokenHandler.count({ role: lastMessage.role, content: lastMessage.content });
lastMessage.tokens = await tokenHandler.countAsync({ role: lastMessage.role, content: lastMessage.content });
}
else {
squashedMessages.push(message);