diff --git a/public/script.js b/public/script.js index 3312c556f..1651d786e 100644 --- a/public/script.js +++ b/public/script.js @@ -3571,7 +3571,9 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro } // Collect messages with usable content - let coreChat = chat.filter(x => !x.is_system); + const canUseTools = ToolManager.isToolCallingSupported(); + const canPerformToolCalls = !dryRun && ToolManager.canPerformToolCalls(type); + let coreChat = chat.filter(x => !x.is_system || (canUseTools && Array.isArray(x.extra?.tool_invocations))); if (type === 'swipe') { coreChat.pop(); } @@ -4406,8 +4408,8 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro getMessage = continue_mag + getMessage; } - if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { - const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls); + if (canPerformToolCalls && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { + const invocations = await ToolManager.invokeFunctionTools(streamingProcessor.toolCalls); if (Array.isArray(invocations) && invocations.length) { const lastMessage = chat[chat.length - 1]; const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); @@ -4455,14 +4457,6 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro throw new Error(data?.response); } - if (ToolManager.isFunctionCallingSupported()) { - const invocations = await ToolManager.checkFunctionToolCalls(data); - if (Array.isArray(invocations) && invocations.length) { - ToolManager.saveFunctionToolInvocations(invocations); - return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); - } - } - //const getData = await response.json(); let getMessage = extractMessageFromData(data); let title = extractTitleFromData(data); @@ -4502,6 +4496,16 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro parseAndSaveLogprobs(data, continue_mag); } + if (canPerformToolCalls) { + const invocations = await ToolManager.invokeFunctionTools(data); + if (Array.isArray(invocations) && invocations.length) { + const shouldDeleteMessage = ['', '...'].includes(getMessage); + shouldDeleteMessage && await deleteLastMessage(); + ToolManager.saveFunctionToolInvocations(invocations); + return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); + } + } + if (type !== 'quiet') { playMessageSound(); } diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 96821eda4..e9216eccf 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -703,7 +703,7 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul } const imageInlining = isImageInliningSupported(); - const toolCalling = ToolManager.isFunctionCallingSupported(); + const canUseTools = ToolManager.isToolCallingSupported(); // Insert chat messages as long as there is budget available const chatPool = [...messages].reverse(); @@ -725,10 +725,10 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul await chatMessage.addImage(chatPrompt.image); } - if (toolCalling && Array.isArray(chatPrompt.invocations)) { + if (canUseTools && Array.isArray(chatPrompt.invocations)) { /** @type {import('./tool-calling.js').ToolInvocation[]} */ const invocations = chatPrompt.invocations; - const toolCallMessage = new Message('assistant', undefined, 'toolCall-' + chatMessage.identifier); + const toolCallMessage = new Message(chatMessage.role, undefined, 'toolCall-' + chatMessage.identifier); toolCallMessage.setToolCalls(invocations); if (chatCompletion.canAfford(toolCallMessage)) { chatCompletion.reserveBudget(toolCallMessage); @@ -1285,7 +1285,7 @@ export async function prepareOpenAIMessages({ const eventData = { chat, dryRun }; await eventSource.emit(event_types.CHAT_COMPLETION_PROMPT_READY, eventData); - openai_messages_count = chat.filter(x => x?.role === 'user' || x?.role === 'assistant')?.length || 0; + openai_messages_count = chat.filter(x => !x?.tool_calls && (x?.role === 'user' || x?.role === 'assistant'))?.length || 0; return [chat, promptManager.tokenHandler.counts]; } @@ -1886,7 +1886,7 @@ async function sendOpenAIRequest(type, messages, signal) { generate_data['seed'] = oai_settings.seed; } - if (!canMultiSwipe && ToolManager.isFunctionCallingSupported()) { + if (!canMultiSwipe && ToolManager.canPerformToolCalls(type)) { await ToolManager.registerFunctionToolsOpenAI(generate_data); } @@ -2393,13 +2393,20 @@ class MessageCollection { } /** - * Get chat in the format of {role, name, content}. + * Get chat in the format of {role, name, content, tool_calls}. * @returns {Array} Array of objects with role, name, and content properties. */ getChat() { return this.collection.reduce((acc, message) => { - const name = message.name; - if (message.content) acc.push({ role: message.role, ...(name && { name }), content: message.content }); + if (message.content || message.tool_calls) { + acc.push({ + role: message.role, + content: message.content, + ...(message.name && { name: message.name }), + ...(message.tool_calls && { tool_calls: message.tool_calls }), + ...(message.role === 'tool' && { tool_call_id: message.identifier }), + }); + } return acc; }, []); } diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index 964093bc3..def0c078a 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -1,4 +1,4 @@ -import { chat, main_api } from '../script.js'; +import { addOneMessage, chat, main_api, system_avatar, systemUserName } from '../script.js'; import { chat_completion_sources, oai_settings } from './openai.js'; /** @@ -243,12 +243,12 @@ export class ToolManager { } } - static isFunctionCallingSupported() { - if (main_api !== 'openai') { - return false; - } - - if (!oai_settings.function_calling) { + /** + * Checks if tool calling is supported for the current settings and generation type. + * @returns {boolean} Whether tool calling is supported for the given type + */ + static isToolCallingSupported() { + if (main_api !== 'openai' || !oai_settings.function_calling) { return false; } @@ -264,6 +264,22 @@ export class ToolManager { return supportedSources.includes(oai_settings.chat_completion_source); } + /** + * Checks if tool calls can be performed for the current settings and generation type. + * @param {string} type Generation type + * @returns {boolean} Whether tool calls can be performed for the given type + */ + static canPerformToolCalls(type) { + const noToolCallTypes = ['swipe', 'impersonate', 'quiet', 'continue']; + const isSupported = ToolManager.isToolCallingSupported(); + return isSupported && !noToolCallTypes.includes(type); + } + + /** + * Utility function to get tool calls from the response data. + * @param {any} data Response data + * @returns {any[]} Tool calls from the response data + */ static #getToolCallsFromData(data) { // Parsed tool calls from streaming data if (Array.isArray(data) && data.length > 0) { @@ -290,15 +306,11 @@ export class ToolManager { * @param {any} data Reply data * @returns {Promise} Successful tool invocations */ - static async checkFunctionToolCalls(data) { - if (!ToolManager.isFunctionCallingSupported()) { - return []; - } - + static async invokeFunctionTools(data) { /** @type {ToolInvocation[]} */ const invocations = []; const toolCalls = ToolManager.#getToolCallsFromData(data); - const oaiCompat = [ + const oaiCompatibleSources = [ chat_completion_sources.OPENAI, chat_completion_sources.CUSTOM, chat_completion_sources.MISTRALAI, @@ -306,7 +318,7 @@ export class ToolManager { chat_completion_sources.GROQ, ]; - if (oaiCompat.includes(oai_settings.chat_completion_source)) { + if (oaiCompatibleSources.includes(oai_settings.chat_completion_source)) { if (!Array.isArray(toolCalls)) { return []; } @@ -323,7 +335,7 @@ export class ToolManager { toastr.info('Invoking function tool: ' + name); const result = await ToolManager.invokeFunctionTool(name, parameters); - toastr.info('Function tool result: ' + result); + console.log('Function tool result:', result); // Save a successful invocation if (result) { @@ -367,15 +379,19 @@ export class ToolManager { * @param {ToolInvocation[]} invocations Successful tool invocations */ static saveFunctionToolInvocations(invocations) { - for (let index = chat.length - 1; index >= 0; index--) { - const message = chat[index]; - if (message.is_user) { - if (!message.extra || typeof message.extra !== 'object') { - message.extra = {}; - } - message.extra.tool_invocations = invocations; - break; - } - } + const toolNames = invocations.map(i => i.name).join(', '); + const message = { + name: systemUserName, + force_avatar: system_avatar, + is_system: true, + is_user: false, + mes: `Performed tool calls: ${toolNames}`, + extra: { + isSmallSys: true, + tool_invocations: invocations, + }, + }; + chat.push(message); + addOneMessage(message); } }