diff --git a/public/script.js b/public/script.js index fac11b8f4..3312c556f 100644 --- a/public/script.js +++ b/public/script.js @@ -4408,7 +4408,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls); - if (invocations.length) { + if (Array.isArray(invocations) && invocations.length) { const lastMessage = chat[chat.length - 1]; const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); if (shouldDeleteMessage) { @@ -4457,7 +4457,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro if (ToolManager.isFunctionCallingSupported()) { const invocations = await ToolManager.checkFunctionToolCalls(data); - if (invocations.length) { + if (Array.isArray(invocations) && invocations.length) { ToolManager.saveFunctionToolInvocations(invocations); return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); } diff --git a/public/scripts/openai.js b/public/scripts/openai.js index a6048dce5..8b9c06048 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -454,7 +454,8 @@ function setOpenAIMessages(chat) { if (role == 'user' && oai_settings.wrap_in_quotes) content = `"${content}"`; const name = chat[j]['name']; const image = chat[j]?.extra?.image; - messages[i] = { 'role': role, 'content': content, name: name, 'image': image }; + const invocations = chat[j]?.extra?.tool_invocations; + messages[i] = { 'role': role, 'content': content, name: name, 'image': image, 'invocations': invocations }; j++; } @@ -702,6 +703,7 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul } const imageInlining = isImageInliningSupported(); + const toolCalling = ToolManager.isFunctionCallingSupported(); // Insert chat messages as long as there is budget available const chatPool = [...messages].reverse(); @@ -723,6 +725,24 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul await chatMessage.addImage(chatPrompt.image); } + if (toolCalling && Array.isArray(chatPrompt.invocations)) { + /** @type {import('./tool-calling.js').ToolInvocation[]} */ + const invocations = chatPrompt.invocations.slice().reverse(); + const toolCallMessage = new Message('assistant', undefined, 'toolCall-' + chatMessage.identifier); + toolCallMessage.setToolCalls(invocations); + if (chatCompletion.canAfford(toolCallMessage)) { + for (const invocation of invocations) { + const toolResultMessage = new Message('tool', invocation.result, invocation.id); + const canAfford = chatCompletion.canAfford(toolResultMessage) && chatCompletion.canAfford(toolCallMessage); + if (!canAfford) { + break; + } + chatCompletion.insertAtStart(toolResultMessage, 'chatHistory'); + } + chatCompletion.insertAtStart(toolCallMessage, 'chatHistory'); + } + } + if (chatCompletion.canAfford(chatMessage)) { if (type === 'continue' && oai_settings.continue_prefill && chatPrompt === firstNonInjected) { // in case we are using continue_prefill and the latest message is an assistant message, we want to prepend the users assistant prefill on the message @@ -2193,6 +2213,8 @@ class Message { content; /** @type {string} */ name; + /** @type {object} */ + tool_call = null; /** * @constructor @@ -2217,6 +2239,22 @@ class Message { } } + /** + * Reconstruct the message from a tool invocation. + * @param {import('./tool-calling.js').ToolInvocation[]} invocations + */ + setToolCalls(invocations) { + this.tool_calls = invocations.map(i => ({ + id: i.id, + type: 'function', + function: { + arguments: i.parameters, + name: i.name, + }, + })); + this.tokens = tokenHandler.count({ role: this.role, tool_calls: JSON.stringify(this.tool_calls) }); + } + setName(name) { this.name = name; this.tokens = tokenHandler.count({ role: this.role, content: this.content, name: this.name }); @@ -2564,7 +2602,7 @@ export class ChatCompletion { this.checkTokenBudget(message, message.identifier); const index = this.findMessageIndex(identifier); - if (message.content) { + if (message.content || message.tool_calls) { if ('start' === position) this.messages.collection[index].collection.unshift(message); else if ('end' === position) this.messages.collection[index].collection.push(message); else if (typeof position === 'number') this.messages.collection[index].collection.splice(position, 0, message); @@ -2633,8 +2671,14 @@ export class ChatCompletion { for (let item of this.messages.collection) { if (item instanceof MessageCollection) { chat.push(...item.getChat()); - } else if (item instanceof Message && item.content) { - const message = { role: item.role, content: item.content, ...(item.name ? { name: item.name } : {}) }; + } else if (item instanceof Message && (item.content || item.tool_calls)) { + const message = { + role: item.role, + content: item.content, + ...(item.name ? { name: item.name } : {}), + ...(item.tool_calls ? { tool_calls: item.tool_calls } : {}), + ...(item.role === 'tool' ? { tool_call_id: item.identifier } : {}), + }; chat.push(message); } else { this.log(`Skipping invalid or empty message in collection: ${JSON.stringify(item)}`); diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index c2f6c209f..c88528dd7 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -307,7 +307,7 @@ export class ToolManager { if (oaiCompat.includes(oai_settings.chat_completion_source)) { if (!Array.isArray(toolCalls)) { - return; + return []; } for (const toolCall of toolCalls) { @@ -363,7 +363,7 @@ export class ToolManager { /** * Saves function tool invocations to the last user chat message extra metadata. - * @param {ToolInvocation[]} invocations + * @param {ToolInvocation[]} invocations Successful tool invocations */ static saveFunctionToolInvocations(invocations) { for (let index = chat.length - 1; index >= 0; index--) { @@ -373,7 +373,6 @@ export class ToolManager { message.extra = {}; } message.extra.tool_invocations = invocations; - debugger; break; } }