Save tool calls to visible chats.

This commit is contained in:
Cohee
2024-10-02 22:17:27 +03:00
parent 3335dbf1a7
commit 0f8c1fa95d
3 changed files with 71 additions and 44 deletions

View File

@ -3571,7 +3571,9 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
} }
// Collect messages with usable content // Collect messages with usable content
let coreChat = chat.filter(x => !x.is_system); const canUseTools = ToolManager.isToolCallingSupported();
const canPerformToolCalls = !dryRun && ToolManager.canPerformToolCalls(type);
let coreChat = chat.filter(x => !x.is_system || (canUseTools && Array.isArray(x.extra?.tool_invocations)));
if (type === 'swipe') { if (type === 'swipe') {
coreChat.pop(); coreChat.pop();
} }
@ -4406,8 +4408,8 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
getMessage = continue_mag + getMessage; getMessage = continue_mag + getMessage;
} }
if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { if (canPerformToolCalls && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) {
const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls); const invocations = await ToolManager.invokeFunctionTools(streamingProcessor.toolCalls);
if (Array.isArray(invocations) && invocations.length) { if (Array.isArray(invocations) && invocations.length) {
const lastMessage = chat[chat.length - 1]; const lastMessage = chat[chat.length - 1];
const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result);
@ -4455,14 +4457,6 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
throw new Error(data?.response); throw new Error(data?.response);
} }
if (ToolManager.isFunctionCallingSupported()) {
const invocations = await ToolManager.checkFunctionToolCalls(data);
if (Array.isArray(invocations) && invocations.length) {
ToolManager.saveFunctionToolInvocations(invocations);
return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun);
}
}
//const getData = await response.json(); //const getData = await response.json();
let getMessage = extractMessageFromData(data); let getMessage = extractMessageFromData(data);
let title = extractTitleFromData(data); let title = extractTitleFromData(data);
@ -4502,6 +4496,16 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
parseAndSaveLogprobs(data, continue_mag); parseAndSaveLogprobs(data, continue_mag);
} }
if (canPerformToolCalls) {
const invocations = await ToolManager.invokeFunctionTools(data);
if (Array.isArray(invocations) && invocations.length) {
const shouldDeleteMessage = ['', '...'].includes(getMessage);
shouldDeleteMessage && await deleteLastMessage();
ToolManager.saveFunctionToolInvocations(invocations);
return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun);
}
}
if (type !== 'quiet') { if (type !== 'quiet') {
playMessageSound(); playMessageSound();
} }

View File

@ -703,7 +703,7 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
} }
const imageInlining = isImageInliningSupported(); const imageInlining = isImageInliningSupported();
const toolCalling = ToolManager.isFunctionCallingSupported(); const canUseTools = ToolManager.isToolCallingSupported();
// Insert chat messages as long as there is budget available // Insert chat messages as long as there is budget available
const chatPool = [...messages].reverse(); const chatPool = [...messages].reverse();
@ -725,10 +725,10 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
await chatMessage.addImage(chatPrompt.image); await chatMessage.addImage(chatPrompt.image);
} }
if (toolCalling && Array.isArray(chatPrompt.invocations)) { if (canUseTools && Array.isArray(chatPrompt.invocations)) {
/** @type {import('./tool-calling.js').ToolInvocation[]} */ /** @type {import('./tool-calling.js').ToolInvocation[]} */
const invocations = chatPrompt.invocations; const invocations = chatPrompt.invocations;
const toolCallMessage = new Message('assistant', undefined, 'toolCall-' + chatMessage.identifier); const toolCallMessage = new Message(chatMessage.role, undefined, 'toolCall-' + chatMessage.identifier);
toolCallMessage.setToolCalls(invocations); toolCallMessage.setToolCalls(invocations);
if (chatCompletion.canAfford(toolCallMessage)) { if (chatCompletion.canAfford(toolCallMessage)) {
chatCompletion.reserveBudget(toolCallMessage); chatCompletion.reserveBudget(toolCallMessage);
@ -1285,7 +1285,7 @@ export async function prepareOpenAIMessages({
const eventData = { chat, dryRun }; const eventData = { chat, dryRun };
await eventSource.emit(event_types.CHAT_COMPLETION_PROMPT_READY, eventData); await eventSource.emit(event_types.CHAT_COMPLETION_PROMPT_READY, eventData);
openai_messages_count = chat.filter(x => x?.role === 'user' || x?.role === 'assistant')?.length || 0; openai_messages_count = chat.filter(x => !x?.tool_calls && (x?.role === 'user' || x?.role === 'assistant'))?.length || 0;
return [chat, promptManager.tokenHandler.counts]; return [chat, promptManager.tokenHandler.counts];
} }
@ -1886,7 +1886,7 @@ async function sendOpenAIRequest(type, messages, signal) {
generate_data['seed'] = oai_settings.seed; generate_data['seed'] = oai_settings.seed;
} }
if (!canMultiSwipe && ToolManager.isFunctionCallingSupported()) { if (!canMultiSwipe && ToolManager.canPerformToolCalls(type)) {
await ToolManager.registerFunctionToolsOpenAI(generate_data); await ToolManager.registerFunctionToolsOpenAI(generate_data);
} }
@ -2393,13 +2393,20 @@ class MessageCollection {
} }
/** /**
* Get chat in the format of {role, name, content}. * Get chat in the format of {role, name, content, tool_calls}.
* @returns {Array} Array of objects with role, name, and content properties. * @returns {Array} Array of objects with role, name, and content properties.
*/ */
getChat() { getChat() {
return this.collection.reduce((acc, message) => { return this.collection.reduce((acc, message) => {
const name = message.name; if (message.content || message.tool_calls) {
if (message.content) acc.push({ role: message.role, ...(name && { name }), content: message.content }); acc.push({
role: message.role,
content: message.content,
...(message.name && { name: message.name }),
...(message.tool_calls && { tool_calls: message.tool_calls }),
...(message.role === 'tool' && { tool_call_id: message.identifier }),
});
}
return acc; return acc;
}, []); }, []);
} }

View File

@ -1,4 +1,4 @@
import { chat, main_api } from '../script.js'; import { addOneMessage, chat, main_api, system_avatar, systemUserName } from '../script.js';
import { chat_completion_sources, oai_settings } from './openai.js'; import { chat_completion_sources, oai_settings } from './openai.js';
/** /**
@ -243,12 +243,12 @@ export class ToolManager {
} }
} }
static isFunctionCallingSupported() { /**
if (main_api !== 'openai') { * Checks if tool calling is supported for the current settings and generation type.
return false; * @returns {boolean} Whether tool calling is supported for the given type
} */
static isToolCallingSupported() {
if (!oai_settings.function_calling) { if (main_api !== 'openai' || !oai_settings.function_calling) {
return false; return false;
} }
@ -264,6 +264,22 @@ export class ToolManager {
return supportedSources.includes(oai_settings.chat_completion_source); return supportedSources.includes(oai_settings.chat_completion_source);
} }
/**
* Checks if tool calls can be performed for the current settings and generation type.
* @param {string} type Generation type
* @returns {boolean} Whether tool calls can be performed for the given type
*/
static canPerformToolCalls(type) {
const noToolCallTypes = ['swipe', 'impersonate', 'quiet', 'continue'];
const isSupported = ToolManager.isToolCallingSupported();
return isSupported && !noToolCallTypes.includes(type);
}
/**
* Utility function to get tool calls from the response data.
* @param {any} data Response data
* @returns {any[]} Tool calls from the response data
*/
static #getToolCallsFromData(data) { static #getToolCallsFromData(data) {
// Parsed tool calls from streaming data // Parsed tool calls from streaming data
if (Array.isArray(data) && data.length > 0) { if (Array.isArray(data) && data.length > 0) {
@ -290,15 +306,11 @@ export class ToolManager {
* @param {any} data Reply data * @param {any} data Reply data
* @returns {Promise<ToolInvocation[]>} Successful tool invocations * @returns {Promise<ToolInvocation[]>} Successful tool invocations
*/ */
static async checkFunctionToolCalls(data) { static async invokeFunctionTools(data) {
if (!ToolManager.isFunctionCallingSupported()) {
return [];
}
/** @type {ToolInvocation[]} */ /** @type {ToolInvocation[]} */
const invocations = []; const invocations = [];
const toolCalls = ToolManager.#getToolCallsFromData(data); const toolCalls = ToolManager.#getToolCallsFromData(data);
const oaiCompat = [ const oaiCompatibleSources = [
chat_completion_sources.OPENAI, chat_completion_sources.OPENAI,
chat_completion_sources.CUSTOM, chat_completion_sources.CUSTOM,
chat_completion_sources.MISTRALAI, chat_completion_sources.MISTRALAI,
@ -306,7 +318,7 @@ export class ToolManager {
chat_completion_sources.GROQ, chat_completion_sources.GROQ,
]; ];
if (oaiCompat.includes(oai_settings.chat_completion_source)) { if (oaiCompatibleSources.includes(oai_settings.chat_completion_source)) {
if (!Array.isArray(toolCalls)) { if (!Array.isArray(toolCalls)) {
return []; return [];
} }
@ -323,7 +335,7 @@ export class ToolManager {
toastr.info('Invoking function tool: ' + name); toastr.info('Invoking function tool: ' + name);
const result = await ToolManager.invokeFunctionTool(name, parameters); const result = await ToolManager.invokeFunctionTool(name, parameters);
toastr.info('Function tool result: ' + result); console.log('Function tool result:', result);
// Save a successful invocation // Save a successful invocation
if (result) { if (result) {
@ -367,15 +379,19 @@ export class ToolManager {
* @param {ToolInvocation[]} invocations Successful tool invocations * @param {ToolInvocation[]} invocations Successful tool invocations
*/ */
static saveFunctionToolInvocations(invocations) { static saveFunctionToolInvocations(invocations) {
for (let index = chat.length - 1; index >= 0; index--) { const toolNames = invocations.map(i => i.name).join(', ');
const message = chat[index]; const message = {
if (message.is_user) { name: systemUserName,
if (!message.extra || typeof message.extra !== 'object') { force_avatar: system_avatar,
message.extra = {}; is_system: true,
} is_user: false,
message.extra.tool_invocations = invocations; mes: `Performed tool calls: ${toolNames}`,
break; extra: {
} isSmallSys: true,
} tool_invocations: invocations,
},
};
chat.push(message);
addOneMessage(message);
} }
} }