Chat Completion: switch to async token handling

commit 8e082e622b (parent b0d0f2111b)
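Note: the sketch below is not part of the diff. It only restates the call-site pattern this commit moves to, using the Message.createAsync and TokenHandler.countAsync APIs introduced further down; the helper name and identifier string are made up.

// Illustrative sketch only (not in the commit): synchronous token counting
// during Message construction is replaced by awaitable factory/count calls.
async function buildExampleMessage(tokenHandler) {
    // New style: the factory awaits the initial token count before returning.
    const message = await Message.createAsync('system', 'Hello there', 'exampleIdentifier');

    // Counting is awaited too; 'prompt' is one of the count buckets kept by TokenHandler.
    const tokens = await tokenHandler.countAsync({ role: message.role, content: message.content }, false, 'prompt');

    return { message, tokens };
}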
@@ -315,7 +315,7 @@ class PromptManager {
      */
     init(moduleConfiguration, serviceSettings) {
         this.configuration = Object.assign(this.configuration, moduleConfiguration);
-        this.tokenHandler = this.tokenHandler || new TokenHandler();
+        this.tokenHandler = this.tokenHandler || new TokenHandler(() => { throw new Error('Token handler not set'); });
         this.serviceSettings = serviceSettings;
         this.containerElement = document.getElementById(this.configuration.containerIdentifier);

@@ -60,7 +60,7 @@ import {
     resetScrollHeight,
     stringFormat,
 } from './utils.js';
-import { countTokensOpenAI, getTokenizerModel } from './tokenizers.js';
+import { countTokensOpenAIAsync, getTokenizerModel } from './tokenizers.js';
 import { isMobile } from './RossAscends-mods.js';
 import { saveLogprobsForActiveMessage } from './logprobs.js';
 import { SlashCommandParser } from './slash-commands/SlashCommandParser.js';

@@ -671,14 +671,14 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul

     // Reserve budget for new chat message
     const newChat = selected_group ? oai_settings.new_group_chat_prompt : oai_settings.new_chat_prompt;
-    const newChatMessage = new Message('system', substituteParams(newChat), 'newMainChat');
+    const newChatMessage = await Message.createAsync('system', substituteParams(newChat), 'newMainChat');
     chatCompletion.reserveBudget(newChatMessage);

     // Reserve budget for group nudge
     let groupNudgeMessage = null;
     const noGroupNudgeTypes = ['impersonate'];
     if (selected_group && prompts.has('groupNudge') && !noGroupNudgeTypes.includes(type)) {
-        groupNudgeMessage = Message.fromPrompt(prompts.get('groupNudge'));
+        groupNudgeMessage = await Message.fromPromptAsync(prompts.get('groupNudge'));
         chatCompletion.reserveBudget(groupNudgeMessage);
     }

@@ -693,12 +693,12 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
         };
         const continuePrompt = new Prompt(promptObject);
         const preparedPrompt = promptManager.preparePrompt(continuePrompt);
-        continueMessage = Message.fromPrompt(preparedPrompt);
+        continueMessage = await Message.fromPromptAsync(preparedPrompt);
         chatCompletion.reserveBudget(continueMessage);
     }

     const lastChatPrompt = messages[messages.length - 1];
-    const message = new Message('user', oai_settings.send_if_empty, 'emptyUserMessageReplacement');
+    const message = await Message.createAsync('user', oai_settings.send_if_empty, 'emptyUserMessageReplacement');
     if (lastChatPrompt && lastChatPrompt.role === 'assistant' && oai_settings.send_if_empty && chatCompletion.canAfford(message)) {
         chatCompletion.insert(message, 'chatHistory');
     }

@@ -715,11 +715,11 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
         // We do not want to mutate the prompt
         const prompt = new Prompt(chatPrompt);
         prompt.identifier = `chatHistory-${messages.length - index}`;
-        const chatMessage = Message.fromPrompt(promptManager.preparePrompt(prompt));
+        const chatMessage = await Message.fromPromptAsync(promptManager.preparePrompt(prompt));

         if (promptManager.serviceSettings.names_behavior === character_names_behavior.COMPLETION && prompt.name) {
             const messageName = promptManager.isValidName(prompt.name) ? prompt.name : promptManager.sanitizeName(prompt.name);
-            chatMessage.setName(messageName);
+            await chatMessage.setName(messageName);
         }

         if (imageInlining && chatPrompt.image) {

@@ -729,9 +729,9 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
         if (canUseTools && Array.isArray(chatPrompt.invocations)) {
             /** @type {import('./tool-calling.js').ToolInvocation[]} */
             const invocations = chatPrompt.invocations;
-            const toolCallMessage = new Message(chatMessage.role, undefined, 'toolCall-' + chatMessage.identifier);
-            const toolResultMessages = invocations.slice().reverse().map((invocation) => new Message('tool', invocation.result || '[No content]', invocation.id));
-            toolCallMessage.setToolCalls(invocations);
+            const toolCallMessage = await Message.createAsync(chatMessage.role, undefined, 'toolCall-' + chatMessage.identifier);
+            const toolResultMessages = await Promise.all(invocations.slice().reverse().map((invocation) => Message.createAsync('tool', invocation.result || '[No content]', invocation.id)));
+            await toolCallMessage.setToolCalls(invocations);
             if (chatCompletion.canAffordAll([toolCallMessage, ...toolResultMessages])) {
                 for (const resultMessage of toolResultMessages) {
                     chatCompletion.insertAtStart(resultMessage, 'chatHistory');

@@ -748,7 +748,8 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
         if (type === 'continue' && oai_settings.continue_prefill && chatPrompt === firstNonInjected) {
             // in case we are using continue_prefill and the latest message is an assistant message, we want to prepend the users assistant prefill on the message
             if (chatPrompt.role === 'assistant') {
-                const collection = new MessageCollection('continuePrefill', new Message(chatMessage.role, substituteParams(oai_settings.assistant_prefill + '\n\n') + chatMessage.content, chatMessage.identifier));
+                const continueMessage = await Message.createAsync(chatMessage.role, substituteParams(oai_settings.assistant_prefill + '\n\n') + chatMessage.content, chatMessage.identifier);
+                const collection = new MessageCollection('continuePrefill', continueMessage);
                 chatCompletion.add(collection, -1);
                 continue;
             }

@@ -787,15 +788,16 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
  * @param {ChatCompletion} chatCompletion - An instance of ChatCompletion class that will be populated with the prompts.
  * @param {Object[]} messageExamples - Array containing all message examples.
  */
-function populateDialogueExamples(prompts, chatCompletion, messageExamples) {
+async function populateDialogueExamples(prompts, chatCompletion, messageExamples) {
     if (!prompts.has('dialogueExamples')) {
         return;
     }

     chatCompletion.add(new MessageCollection('dialogueExamples'), prompts.index('dialogueExamples'));
     if (Array.isArray(messageExamples) && messageExamples.length) {
-        const newExampleChat = new Message('system', substituteParams(oai_settings.new_example_chat_prompt), 'newChat');
-        [...messageExamples].forEach((dialogue, dialogueIndex) => {
+        const newExampleChat = await Message.createAsync('system', substituteParams(oai_settings.new_example_chat_prompt), 'newChat');
+        for (const dialogue of [...messageExamples]) {
+            const dialogueIndex = messageExamples.indexOf(dialogue);
             let examplesAdded = 0;

             if (chatCompletion.canAfford(newExampleChat)) chatCompletion.insert(newExampleChat, 'dialogueExamples');

@@ -806,8 +808,8 @@ function populateDialogueExamples(prompts, chatCompletion, messageExamples) {
                 const content = prompt.content || '';
                 const identifier = `dialogueExamples ${dialogueIndex}-${promptIndex}`;

-                const chatMessage = new Message(role, content, identifier);
-                chatMessage.setName(prompt.name);
+                const chatMessage = await Message.createAsync(role, content, identifier);
+                await chatMessage.setName(prompt.name);
                 if (!chatCompletion.canAfford(chatMessage)) {
                     break;
                 }

@@ -818,7 +820,7 @@ function populateDialogueExamples(prompts, chatCompletion, messageExamples) {
             if (0 === examplesAdded) {
                 chatCompletion.removeLastFrom('dialogueExamples');
             }
-        });
+        }
     }
 }

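Aside (not part of the diff): the forEach callback in populateDialogueExamples is rewritten as a for...of loop because an await inside a forEach callback does not pause the enclosing function. A minimal sketch of the difference, with a made-up async helper handle():

// Illustrative sketch only; handle() is hypothetical.
async function processAll(items) {
    // items.forEach(async (item) => { await handle(item); });
    // ^ forEach ignores the returned promises, so processAll() would resolve too early.
    for (const item of items) {
        await handle(item); // each iteration completes before the next starts
    }
}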
@@ -873,7 +875,7 @@ function getPromptRole(role) {
  */
 async function populateChatCompletion(prompts, chatCompletion, { bias, quietPrompt, quietImage, type, cyclePrompt, messages, messageExamples }) {
     // Helper function for preparing a prompt, that already exists within the prompt collection, for completion
-    const addToChatCompletion = (source, target = null) => {
+    const addToChatCompletion = async (source, target = null) => {
         // We need the prompts array to determine a position for the source.
         if (false === prompts.has(source)) return;

@@ -891,30 +893,31 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm

         const index = target ? prompts.index(target) : prompts.index(source);
         const collection = new MessageCollection(source);
-        collection.add(Message.fromPrompt(prompt));
+        const message = await Message.fromPromptAsync(prompt);
+        collection.add(message);
         chatCompletion.add(collection, index);
     };

     chatCompletion.reserveBudget(3); // every reply is primed with <|start|>assistant<|message|>
     // Character and world information
-    addToChatCompletion('worldInfoBefore');
-    addToChatCompletion('main');
-    addToChatCompletion('worldInfoAfter');
-    addToChatCompletion('charDescription');
-    addToChatCompletion('charPersonality');
-    addToChatCompletion('scenario');
-    addToChatCompletion('personaDescription');
+    await addToChatCompletion('worldInfoBefore');
+    await addToChatCompletion('main');
+    await addToChatCompletion('worldInfoAfter');
+    await addToChatCompletion('charDescription');
+    await addToChatCompletion('charPersonality');
+    await addToChatCompletion('scenario');
+    await addToChatCompletion('personaDescription');

     // Collection of control prompts that will always be positioned last
     chatCompletion.setOverriddenPrompts(prompts.overriddenPrompts);
     const controlPrompts = new MessageCollection('controlPrompts');

-    const impersonateMessage = Message.fromPrompt(prompts.get('impersonate')) ?? null;
+    const impersonateMessage = await Message.fromPromptAsync(prompts.get('impersonate')) ?? null;
     if (type === 'impersonate') controlPrompts.add(impersonateMessage);

     // Add quiet prompt to control prompts
     // This should always be last, even in control prompts. Add all further control prompts BEFORE this prompt
-    const quietPromptMessage = Message.fromPrompt(prompts.get('quietPrompt')) ?? null;
+    const quietPromptMessage = await Message.fromPromptAsync(prompts.get('quietPrompt')) ?? null;
     if (quietPromptMessage && quietPromptMessage.content) {
         if (isImageInliningSupported() && quietImage) {
             await quietPromptMessage.addImage(quietImage);

@@ -940,20 +943,23 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
         return acc;
     }, []);

-    [...systemPrompts, ...userRelativePrompts].forEach(identifier => addToChatCompletion(identifier));
+    for (const identifier of [...systemPrompts, ...userRelativePrompts]) {
+        await addToChatCompletion(identifier);
+    }

     // Add enhance definition instruction
-    if (prompts.has('enhanceDefinitions')) addToChatCompletion('enhanceDefinitions');
+    if (prompts.has('enhanceDefinitions')) await addToChatCompletion('enhanceDefinitions');

     // Bias
-    if (bias && bias.trim().length) addToChatCompletion('bias');
+    if (bias && bias.trim().length) await addToChatCompletion('bias');

     // Tavern Extras - Summary
     if (prompts.has('summary')) {
         const summary = prompts.get('summary');

         if (summary.position) {
-            chatCompletion.insert(Message.fromPrompt(summary), 'main', summary.position);
+            const message = await Message.fromPromptAsync(summary);
+            chatCompletion.insert(message, 'main', summary.position);
         }
     }

@@ -962,7 +968,8 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
         const authorsNote = prompts.get('authorsNote');

         if (authorsNote.position) {
-            chatCompletion.insert(Message.fromPrompt(authorsNote), 'main', authorsNote.position);
+            const message = await Message.fromPromptAsync(authorsNote);
+            chatCompletion.insert(message, 'main', authorsNote.position);
         }
     }

@@ -971,7 +978,8 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
         const vectorsMemory = prompts.get('vectorsMemory');

         if (vectorsMemory.position) {
-            chatCompletion.insert(Message.fromPrompt(vectorsMemory), 'main', vectorsMemory.position);
+            const message = await Message.fromPromptAsync(vectorsMemory);
+            chatCompletion.insert(message, 'main', vectorsMemory.position);
         }
     }

@@ -980,7 +988,8 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
         const vectorsDataBank = prompts.get('vectorsDataBank');

         if (vectorsDataBank.position) {
-            chatCompletion.insert(Message.fromPrompt(vectorsDataBank), 'main', vectorsDataBank.position);
+            const message = await Message.fromPromptAsync(vectorsDataBank);
+            chatCompletion.insert(message, 'main', vectorsDataBank.position);
         }
     }

@@ -989,13 +998,15 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
         const smartContext = prompts.get('smartContext');

         if (smartContext.position) {
-            chatCompletion.insert(Message.fromPrompt(smartContext), 'main', smartContext.position);
+            const message = await Message.fromPromptAsync(smartContext);
+            chatCompletion.insert(message, 'main', smartContext.position);
         }
     }

     // Other relative extension prompts
     for (const prompt of prompts.collection.filter(p => p.extension && p.position)) {
-        chatCompletion.insert(Message.fromPrompt(prompt), 'main', prompt.position);
+        const message = await Message.fromPromptAsync(prompt);
+        chatCompletion.insert(message, 'main', prompt.position);
     }

     // Pre-allocation of tokens for tool data

@@ -1003,7 +1014,7 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
         const toolData = {};
         await ToolManager.registerFunctionToolsOpenAI(toolData);
         const toolMessage = [{ role: 'user', content: JSON.stringify(toolData) }];
-        const toolTokens = tokenHandler.count(toolMessage);
+        const toolTokens = await tokenHandler.countAsync(toolMessage);
         chatCompletion.reserveBudget(toolTokens);
     }

@@ -1012,11 +1023,11 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm

     // Decide whether dialogue examples should always be added
     if (power_user.pin_examples) {
-        populateDialogueExamples(prompts, chatCompletion, messageExamples);
+        await populateDialogueExamples(prompts, chatCompletion, messageExamples);
         await populateChatHistory(messages, prompts, chatCompletion, type, cyclePrompt);
     } else {
         await populateChatHistory(messages, prompts, chatCompletion, type, cyclePrompt);
-        populateDialogueExamples(prompts, chatCompletion, messageExamples);
+        await populateDialogueExamples(prompts, chatCompletion, messageExamples);
     }

     chatCompletion.freeBudget(controlPrompts);

@@ -1281,7 +1292,7 @@ export async function prepareOpenAIMessages({
     promptManager.setChatCompletion(chatCompletion);

     if (oai_settings.squash_system_messages && dryRun == false) {
-        chatCompletion.squashSystemMessages();
+        await chatCompletion.squashSystemMessages();
     }

     // All information is up-to-date, render.

@@ -2127,8 +2138,11 @@ async function calculateLogitBias() {
 }

 class TokenHandler {
-    constructor(countTokenFn) {
-        this.countTokenFn = countTokenFn;
+    /**
+     * @param {(messages: object[] | object, full?: boolean) => Promise<number>} countTokenAsyncFn Function to count tokens
+     */
+    constructor(countTokenAsyncFn) {
+        this.countTokenAsyncFn = countTokenAsyncFn;
         this.counts = {
             'start_chat': 0,
             'prompt': 0,

@@ -2157,8 +2171,15 @@ class TokenHandler {
         this.counts[type] -= value;
     }

-    count(messages, full, type) {
-        const token_count = this.countTokenFn(messages, full);
+    /**
+     * Count tokens for a message or messages.
+     * @param {object|any[]} messages Messages to count tokens for
+     * @param {boolean} [full] Count full tokens
+     * @param {string} [type] Identifier for the token count
+     * @returns {Promise<number>} The token count
+     */
+    async countAsync(messages, full, type) {
+        const token_count = await this.countTokenAsyncFn(messages, full);
         this.counts[type] += token_count;

         return token_count;

@@ -2178,7 +2199,7 @@ class TokenHandler {
 }


-const tokenHandler = new TokenHandler(countTokensOpenAI);
+const tokenHandler = new TokenHandler(countTokensOpenAIAsync);

 // Thrown by ChatCompletion when a requested prompt couldn't be found.
 class IdentifierNotFoundError extends Error {

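Usage note (not part of the diff): with the module-level handler now built from countTokensOpenAIAsync, every caller awaits the count. A small sketch, assuming a hypothetical measurePrompt() wrapper:

// Illustrative sketch only.
async function measurePrompt(promptText) {
    // countAsync(messages, full, type) resolves to the token count and
    // accumulates it under the given bucket ('prompt' is one of the defaults).
    return await tokenHandler.countAsync([{ role: 'user', content: promptText }], false, 'prompt');
}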
@@ -2228,6 +2249,7 @@ class Message {
      * @param {string} role - The role of the entity creating the message.
      * @param {string} content - The actual content of the message.
      * @param {string} identifier - A unique identifier for the message.
+     * @private Don't use this constructor directly. Use createAsync instead.
      */
     constructor(role, content, identifier) {
         this.identifier = identifier;

@@ -2239,18 +2261,32 @@ class Message {
             this.role = 'system';
         }

-        if (typeof this.content === 'string' && this.content.length > 0) {
-            this.tokens = tokenHandler.count({ role: this.role, content: this.content });
-        } else {
-            this.tokens = 0;
-        }
+        this.tokens = 0;
     }

+    /**
+     * Create a new Message instance.
+     * @param {string} role
+     * @param {string} content
+     * @param {string} identifier
+     * @returns {Promise<Message>} Message instance
+     */
+    static async createAsync(role, content, identifier) {
+        const message = new Message(role, content, identifier);
+
+        if (typeof message.content === 'string' && message.content.length > 0) {
+            message.tokens = await tokenHandler.countAsync({ role: message.role, content: message.content });
+        }
+
+        return message;
+    }
+
     /**
      * Reconstruct the message from a tool invocation.
-     * @param {import('./tool-calling.js').ToolInvocation[]} invocations
+     * @param {import('./tool-calling.js').ToolInvocation[]} invocations - The tool invocations to reconstruct the message from.
+     * @returns {Promise<void>}
      */
-    setToolCalls(invocations) {
+    async setToolCalls(invocations) {
         this.tool_calls = invocations.map(i => ({
             id: i.id,
             type: 'function',

@@ -2259,12 +2295,17 @@ class Message {
                 name: i.name,
             },
         }));
-        this.tokens = tokenHandler.count({ role: this.role, tool_calls: JSON.stringify(this.tool_calls) });
+        this.tokens = await tokenHandler.countAsync({ role: this.role, tool_calls: JSON.stringify(this.tool_calls) });
     }

-    setName(name) {
+    /**
+     * Add a name to the message.
+     * @param {string} name Name to set for the message.
+     * @returns {Promise<void>}
+     */
+    async setName(name) {
         this.name = name;
-        this.tokens = tokenHandler.count({ role: this.role, content: this.content, name: this.name });
+        this.tokens = await tokenHandler.countAsync({ role: this.role, content: this.content, name: this.name });
     }

     async addImage(image) {

@@ -2356,13 +2397,13 @@ class Message {
     }

     /**
-     * Create a new Message instance from a prompt.
+     * Create a new Message instance from a prompt asynchronously.
      * @static
      * @param {Object} prompt - The prompt object.
-     * @returns {Message} A new instance of Message.
+     * @returns {Promise<Message>} A new instance of Message.
      */
-    static fromPrompt(prompt) {
-        return new Message(prompt.role, prompt.content, prompt.identifier);
+    static async fromPromptAsync(prompt) {
+        return Message.createAsync(prompt.role, prompt.content, prompt.identifier);
     }

     /**

@@ -2488,8 +2529,9 @@ export class ChatCompletion {

     /**
      * Combines consecutive system messages into one if they have no name attached.
+     * @returns {Promise<void>}
      */
-    squashSystemMessages() {
+    async squashSystemMessages() {
         const excludeList = ['newMainChat', 'newChat', 'groupNudge'];
         this.messages.collection = this.messages.flatten();

@@ -2509,7 +2551,7 @@ export class ChatCompletion {
             if (shouldSquash(message)) {
                 if (lastMessage && shouldSquash(lastMessage)) {
                     lastMessage.content += '\n' + message.content;
-                    lastMessage.tokens = tokenHandler.count({ role: lastMessage.role, content: lastMessage.content });
+                    lastMessage.tokens = await tokenHandler.countAsync({ role: lastMessage.role, content: lastMessage.content });
                 }
                 else {
                     squashedMessages.push(message);
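End note (not part of the diff): because setName() and setToolCalls() now re-count tokens through countAsync, they must be awaited wherever a message is decorated after creation. A short sketch with a made-up helper name and identifier:

// Illustrative sketch only.
async function buildNamedMessage(role, content, name) {
    const message = await Message.createAsync(role, content, 'sketchIdentifier');
    await message.setName(name); // setName re-counts tokens asynchronously
    return message;
}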