Save tool calls to visible chats.

This commit is contained in:
Cohee 2024-10-02 22:17:27 +03:00
parent 3335dbf1a7
commit 0f8c1fa95d
3 changed files with 71 additions and 44 deletions

View File

@ -3571,7 +3571,9 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
}
// Collect messages with usable content
let coreChat = chat.filter(x => !x.is_system);
const canUseTools = ToolManager.isToolCallingSupported();
const canPerformToolCalls = !dryRun && ToolManager.canPerformToolCalls(type);
let coreChat = chat.filter(x => !x.is_system || (canUseTools && Array.isArray(x.extra?.tool_invocations)));
if (type === 'swipe') {
coreChat.pop();
}
@ -4406,8 +4408,8 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
getMessage = continue_mag + getMessage;
}
if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) {
const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls);
if (canPerformToolCalls && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) {
const invocations = await ToolManager.invokeFunctionTools(streamingProcessor.toolCalls);
if (Array.isArray(invocations) && invocations.length) {
const lastMessage = chat[chat.length - 1];
const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result);
@ -4455,14 +4457,6 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
throw new Error(data?.response);
}
if (ToolManager.isFunctionCallingSupported()) {
const invocations = await ToolManager.checkFunctionToolCalls(data);
if (Array.isArray(invocations) && invocations.length) {
ToolManager.saveFunctionToolInvocations(invocations);
return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun);
}
}
//const getData = await response.json();
let getMessage = extractMessageFromData(data);
let title = extractTitleFromData(data);
@ -4502,6 +4496,16 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
parseAndSaveLogprobs(data, continue_mag);
}
if (canPerformToolCalls) {
const invocations = await ToolManager.invokeFunctionTools(data);
if (Array.isArray(invocations) && invocations.length) {
const shouldDeleteMessage = ['', '...'].includes(getMessage);
shouldDeleteMessage && await deleteLastMessage();
ToolManager.saveFunctionToolInvocations(invocations);
return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun);
}
}
if (type !== 'quiet') {
playMessageSound();
}

View File

@ -703,7 +703,7 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
}
const imageInlining = isImageInliningSupported();
const toolCalling = ToolManager.isFunctionCallingSupported();
const canUseTools = ToolManager.isToolCallingSupported();
// Insert chat messages as long as there is budget available
const chatPool = [...messages].reverse();
@ -725,10 +725,10 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
await chatMessage.addImage(chatPrompt.image);
}
if (toolCalling && Array.isArray(chatPrompt.invocations)) {
if (canUseTools && Array.isArray(chatPrompt.invocations)) {
/** @type {import('./tool-calling.js').ToolInvocation[]} */
const invocations = chatPrompt.invocations;
const toolCallMessage = new Message('assistant', undefined, 'toolCall-' + chatMessage.identifier);
const toolCallMessage = new Message(chatMessage.role, undefined, 'toolCall-' + chatMessage.identifier);
toolCallMessage.setToolCalls(invocations);
if (chatCompletion.canAfford(toolCallMessage)) {
chatCompletion.reserveBudget(toolCallMessage);
@ -1285,7 +1285,7 @@ export async function prepareOpenAIMessages({
const eventData = { chat, dryRun };
await eventSource.emit(event_types.CHAT_COMPLETION_PROMPT_READY, eventData);
openai_messages_count = chat.filter(x => x?.role === 'user' || x?.role === 'assistant')?.length || 0;
openai_messages_count = chat.filter(x => !x?.tool_calls && (x?.role === 'user' || x?.role === 'assistant'))?.length || 0;
return [chat, promptManager.tokenHandler.counts];
}
@ -1886,7 +1886,7 @@ async function sendOpenAIRequest(type, messages, signal) {
generate_data['seed'] = oai_settings.seed;
}
if (!canMultiSwipe && ToolManager.isFunctionCallingSupported()) {
if (!canMultiSwipe && ToolManager.canPerformToolCalls(type)) {
await ToolManager.registerFunctionToolsOpenAI(generate_data);
}
@ -2393,13 +2393,20 @@ class MessageCollection {
}
/**
* Get chat in the format of {role, name, content}.
* Get chat in the format of {role, name, content, tool_calls}.
* @returns {Array} Array of objects with role, name, and content properties.
*/
getChat() {
return this.collection.reduce((acc, message) => {
const name = message.name;
if (message.content) acc.push({ role: message.role, ...(name && { name }), content: message.content });
if (message.content || message.tool_calls) {
acc.push({
role: message.role,
content: message.content,
...(message.name && { name: message.name }),
...(message.tool_calls && { tool_calls: message.tool_calls }),
...(message.role === 'tool' && { tool_call_id: message.identifier }),
});
}
return acc;
}, []);
}

View File

@ -1,4 +1,4 @@
import { chat, main_api } from '../script.js';
import { addOneMessage, chat, main_api, system_avatar, systemUserName } from '../script.js';
import { chat_completion_sources, oai_settings } from './openai.js';
/**
@ -243,12 +243,12 @@ export class ToolManager {
}
}
static isFunctionCallingSupported() {
if (main_api !== 'openai') {
return false;
}
if (!oai_settings.function_calling) {
/**
* Checks if tool calling is supported for the current settings and generation type.
* @returns {boolean} Whether tool calling is supported for the given type
*/
static isToolCallingSupported() {
if (main_api !== 'openai' || !oai_settings.function_calling) {
return false;
}
@ -264,6 +264,22 @@ export class ToolManager {
return supportedSources.includes(oai_settings.chat_completion_source);
}
/**
* Checks if tool calls can be performed for the current settings and generation type.
* @param {string} type Generation type
* @returns {boolean} Whether tool calls can be performed for the given type
*/
static canPerformToolCalls(type) {
const noToolCallTypes = ['swipe', 'impersonate', 'quiet', 'continue'];
const isSupported = ToolManager.isToolCallingSupported();
return isSupported && !noToolCallTypes.includes(type);
}
/**
* Utility function to get tool calls from the response data.
* @param {any} data Response data
* @returns {any[]} Tool calls from the response data
*/
static #getToolCallsFromData(data) {
// Parsed tool calls from streaming data
if (Array.isArray(data) && data.length > 0) {
@ -290,15 +306,11 @@ export class ToolManager {
* @param {any} data Reply data
* @returns {Promise<ToolInvocation[]>} Successful tool invocations
*/
static async checkFunctionToolCalls(data) {
if (!ToolManager.isFunctionCallingSupported()) {
return [];
}
static async invokeFunctionTools(data) {
/** @type {ToolInvocation[]} */
const invocations = [];
const toolCalls = ToolManager.#getToolCallsFromData(data);
const oaiCompat = [
const oaiCompatibleSources = [
chat_completion_sources.OPENAI,
chat_completion_sources.CUSTOM,
chat_completion_sources.MISTRALAI,
@ -306,7 +318,7 @@ export class ToolManager {
chat_completion_sources.GROQ,
];
if (oaiCompat.includes(oai_settings.chat_completion_source)) {
if (oaiCompatibleSources.includes(oai_settings.chat_completion_source)) {
if (!Array.isArray(toolCalls)) {
return [];
}
@ -323,7 +335,7 @@ export class ToolManager {
toastr.info('Invoking function tool: ' + name);
const result = await ToolManager.invokeFunctionTool(name, parameters);
toastr.info('Function tool result: ' + result);
console.log('Function tool result:', result);
// Save a successful invocation
if (result) {
@ -367,15 +379,19 @@ export class ToolManager {
* @param {ToolInvocation[]} invocations Successful tool invocations
*/
static saveFunctionToolInvocations(invocations) {
for (let index = chat.length - 1; index >= 0; index--) {
const message = chat[index];
if (message.is_user) {
if (!message.extra || typeof message.extra !== 'object') {
message.extra = {};
}
message.extra.tool_invocations = invocations;
break;
}
}
const toolNames = invocations.map(i => i.name).join(', ');
const message = {
name: systemUserName,
force_avatar: system_avatar,
is_system: true,
is_user: false,
mes: `Performed tool calls: ${toolNames}`,
extra: {
isSmallSys: true,
tool_invocations: invocations,
},
};
chat.push(message);
addOneMessage(message);
}
}