Tool Calling: Implement stealth tool defintions (#3192)

* Tool Calling: Implement stealth tool defintions

* Move isStealth check up

* Always stop generation on stealth tool calls

* Image Generation: use stealth flag for tool registration

* Update stealth property description to clarify no follow-up generation will be performed

* Revert "Image Generation: use stealth flag for tool registration"

This reverts commit 8d13445c0b.
This commit is contained in:
Cohee
2024-12-19 21:17:47 +02:00
committed by GitHub
parent e83182c03b
commit 7e7b3e30c4
2 changed files with 68 additions and 9 deletions

View File

@ -4579,9 +4579,12 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
const shouldDeleteMessage = type !== 'swipe' && ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor?.result); const shouldDeleteMessage = type !== 'swipe' && ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor?.result);
hasToolCalls && shouldDeleteMessage && await deleteLastMessage(); hasToolCalls && shouldDeleteMessage && await deleteLastMessage();
const invocationResult = await ToolManager.invokeFunctionTools(streamingProcessor.toolCalls); const invocationResult = await ToolManager.invokeFunctionTools(streamingProcessor.toolCalls);
const shouldStopGeneration = (!invocationResult.invocations.length && shouldDeleteMessage) || invocationResult.stealthCalls.length;
if (hasToolCalls) { if (hasToolCalls) {
if (!invocationResult.invocations.length && shouldDeleteMessage) { if (shouldStopGeneration) {
if (Array.isArray(invocationResult.errors) && invocationResult.errors.length) {
ToolManager.showToolCallError(invocationResult.errors); ToolManager.showToolCallError(invocationResult.errors);
}
unblockGeneration(type); unblockGeneration(type);
generatedPromptCache = ''; generatedPromptCache = '';
streamingProcessor = null; streamingProcessor = null;
@ -4681,9 +4684,12 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
const shouldDeleteMessage = type !== 'swipe' && ['', '...'].includes(getMessage); const shouldDeleteMessage = type !== 'swipe' && ['', '...'].includes(getMessage);
hasToolCalls && shouldDeleteMessage && await deleteLastMessage(); hasToolCalls && shouldDeleteMessage && await deleteLastMessage();
const invocationResult = await ToolManager.invokeFunctionTools(data); const invocationResult = await ToolManager.invokeFunctionTools(data);
const shouldStopGeneration = (!invocationResult.invocations.length && shouldDeleteMessage) || invocationResult.stealthCalls.length;
if (hasToolCalls) { if (hasToolCalls) {
if (!invocationResult.invocations.length && shouldDeleteMessage) { if (shouldStopGeneration) {
if (Array.isArray(invocationResult.errors) && invocationResult.errors.length) {
ToolManager.showToolCallError(invocationResult.errors); ToolManager.showToolCallError(invocationResult.errors);
}
unblockGeneration(type); unblockGeneration(type);
generatedPromptCache = ''; generatedPromptCache = '';
return; return;

View File

@ -25,6 +25,7 @@ import { isTrueBoolean } from './utils.js';
* @typedef {object} ToolInvocationResult * @typedef {object} ToolInvocationResult
* @property {ToolInvocation[]} invocations Successful tool invocations * @property {ToolInvocation[]} invocations Successful tool invocations
* @property {Error[]} errors Errors that occurred during tool invocation * @property {Error[]} errors Errors that occurred during tool invocation
* @property {string[]} stealthCalls Names of stealth tools that were invoked
*/ */
/** /**
@ -36,6 +37,7 @@ import { isTrueBoolean } from './utils.js';
* @property {function} action - The action to perform when the tool is invoked. * @property {function} action - The action to perform when the tool is invoked.
* @property {function} [formatMessage] - A function to format the tool call message. * @property {function} [formatMessage] - A function to format the tool call message.
* @property {function} [shouldRegister] - A function to determine if the tool should be registered. * @property {function} [shouldRegister] - A function to determine if the tool should be registered.
* @property {boolean} [stealth] - A tool call result will not be shown in the chat. No follow-up generation will be performed.
*/ */
/** /**
@ -147,6 +149,12 @@ class ToolDefinition {
*/ */
#shouldRegister; #shouldRegister;
/**
* A tool call result will not be shown in the chat. No follow-up generation will be performed.
* @type {boolean}
*/
#stealth;
/** /**
* Creates a new ToolDefinition. * Creates a new ToolDefinition.
* @param {string} name A unique name for the tool. * @param {string} name A unique name for the tool.
@ -156,8 +164,9 @@ class ToolDefinition {
* @param {function} action A function that will be called when the tool is executed. * @param {function} action A function that will be called when the tool is executed.
* @param {function} formatMessage A function that will be called to format the tool call toast. * @param {function} formatMessage A function that will be called to format the tool call toast.
* @param {function} shouldRegister A function that will be called to determine if the tool should be registered. * @param {function} shouldRegister A function that will be called to determine if the tool should be registered.
* @param {boolean} stealth A tool call result will not be shown in the chat. No follow-up generation will be performed.
*/ */
constructor(name, displayName, description, parameters, action, formatMessage, shouldRegister) { constructor(name, displayName, description, parameters, action, formatMessage, shouldRegister, stealth) {
this.#name = name; this.#name = name;
this.#displayName = displayName; this.#displayName = displayName;
this.#description = description; this.#description = description;
@ -165,6 +174,7 @@ class ToolDefinition {
this.#action = action; this.#action = action;
this.#formatMessage = formatMessage; this.#formatMessage = formatMessage;
this.#shouldRegister = shouldRegister; this.#shouldRegister = shouldRegister;
this.#stealth = stealth;
} }
/** /**
@ -214,6 +224,10 @@ class ToolDefinition {
get displayName() { get displayName() {
return this.#displayName; return this.#displayName;
} }
get stealth() {
return this.#stealth;
}
} }
/** /**
@ -246,7 +260,7 @@ export class ToolManager {
* Registers a new tool with the tool registry. * Registers a new tool with the tool registry.
* @param {ToolRegistration} tool The tool to register. * @param {ToolRegistration} tool The tool to register.
*/ */
static registerFunctionTool({ name, displayName, description, parameters, action, formatMessage, shouldRegister }) { static registerFunctionTool({ name, displayName, description, parameters, action, formatMessage, shouldRegister, stealth }) {
// Convert WIP arguments // Convert WIP arguments
if (typeof arguments[0] !== 'object') { if (typeof arguments[0] !== 'object') {
[name, description, parameters, action] = arguments; [name, description, parameters, action] = arguments;
@ -256,7 +270,16 @@ export class ToolManager {
console.warn(`[ToolManager] A tool with the name "${name}" has already been registered. The definition will be overwritten.`); console.warn(`[ToolManager] A tool with the name "${name}" has already been registered. The definition will be overwritten.`);
} }
const definition = new ToolDefinition(name, displayName, description, parameters, action, formatMessage, shouldRegister); const definition = new ToolDefinition(
name,
displayName,
description,
parameters,
action,
formatMessage,
shouldRegister,
stealth,
);
this.#tools.set(name, definition); this.#tools.set(name, definition);
console.log('[ToolManager] Registered function tool:', definition); console.log('[ToolManager] Registered function tool:', definition);
} }
@ -302,6 +325,20 @@ export class ToolManager {
} }
} }
/**
* Checks if a tool is a stealth tool.
* @param {string} name The name of the tool to check.
* @returns {boolean} Whether the tool is a stealth tool.
*/
static isStealthTool(name) {
if (!this.#tools.has(name)) {
return false;
}
const tool = this.#tools.get(name);
return !!tool.stealth;
}
/** /**
* Formats a message for a tool call by name. * Formats a message for a tool call by name.
* @param {string} name The name of the tool to format the message for. * @param {string} name The name of the tool to format the message for.
@ -608,6 +645,7 @@ export class ToolManager {
const result = { const result = {
invocations: [], invocations: [],
errors: [], errors: [],
stealthCalls: [],
}; };
const toolCalls = ToolManager.#getToolCallsFromData(data); const toolCalls = ToolManager.#getToolCallsFromData(data);
@ -625,7 +663,7 @@ export class ToolManager {
const parameters = toolCall.function.arguments; const parameters = toolCall.function.arguments;
const name = toolCall.function.name; const name = toolCall.function.name;
const displayName = ToolManager.getDisplayName(name); const displayName = ToolManager.getDisplayName(name);
const isStealth = ToolManager.isStealthTool(name);
const message = await ToolManager.formatToolCallMessage(name, parameters); const message = await ToolManager.formatToolCallMessage(name, parameters);
const toast = message && toastr.info(message, 'Tool Calling', { timeOut: 0 }); const toast = message && toastr.info(message, 'Tool Calling', { timeOut: 0 });
const toolResult = await ToolManager.invokeFunctionTool(name, parameters); const toolResult = await ToolManager.invokeFunctionTool(name, parameters);
@ -638,6 +676,12 @@ export class ToolManager {
continue; continue;
} }
// Don't save stealth tool invocations
if (isStealth) {
result.stealthCalls.push(name);
continue;
}
const invocation = { const invocation = {
id, id,
displayName, displayName,
@ -860,6 +904,14 @@ export class ToolManager {
isRequired: false, isRequired: false,
acceptsMultiple: false, acceptsMultiple: false,
}), }),
SlashCommandNamedArgument.fromProps({
name: 'stealth',
description: 'If true, a tool call result will not be shown in the chat and no follow-up generation will be performed.',
typeList: [ARGUMENT_TYPE.BOOLEAN],
isRequired: false,
acceptsMultiple: false,
defaultValue: String(false),
}),
], ],
unnamedArgumentList: [ unnamedArgumentList: [
SlashCommandArgument.fromProps({ SlashCommandArgument.fromProps({
@ -891,7 +943,7 @@ export class ToolManager {
}; };
} }
const { name, displayName, description, parameters, formatMessage, shouldRegister } = args; const { name, displayName, description, parameters, formatMessage, shouldRegister, stealth } = args;
if (!(action instanceof SlashCommandClosure)) { if (!(action instanceof SlashCommandClosure)) {
throw new Error('The unnamed argument must be a closure.'); throw new Error('The unnamed argument must be a closure.');
@ -927,6 +979,7 @@ export class ToolManager {
action: actionFunc, action: actionFunc,
formatMessage: formatMessageFunc, formatMessage: formatMessageFunc,
shouldRegister: shouldRegisterFunc, shouldRegister: shouldRegisterFunc,
stealth: stealth && isTrueBoolean(String(stealth)),
}); });
return ''; return '';