Tool Calling: add shouldRegister function to tool defintion

This commit is contained in:
Cohee 2024-10-09 03:53:32 +03:00
parent 10f51f5d90
commit bc0f5bf4ce

View File

@ -32,6 +32,7 @@ import { slashCommandReturnHelper } from './slash-commands/SlashCommandReturnHel
* @property {object} parameters - The parameters for the tool.
* @property {function} action - The action to perform when the tool is invoked.
* @property {function} formatMessage - A function to format the tool call message.
* @property {function} shouldRegister - A function to determine if the tool should be registered.
*/
/**
@ -137,6 +138,12 @@ class ToolDefinition {
*/
#formatMessage;
/**
* A function that will be called to determine if the tool should be registered.
* @type {function}
*/
#shouldRegister;
/**
* Creates a new ToolDefinition.
* @param {string} name A unique name for the tool.
@ -145,14 +152,16 @@ class ToolDefinition {
* @param {object} parameters A JSON schema for the parameters that the tool accepts.
* @param {function} action A function that will be called when the tool is executed.
* @param {function} formatMessage A function that will be called to format the tool call toast.
* @param {function} shouldRegister A function that will be called to determine if the tool should be registered.
*/
constructor(name, displayName, description, parameters, action, formatMessage) {
constructor(name, displayName, description, parameters, action, formatMessage, shouldRegister) {
this.#name = name;
this.#displayName = displayName;
this.#description = description;
this.#parameters = parameters;
this.#action = action;
this.#formatMessage = formatMessage;
this.#shouldRegister = shouldRegister;
}
/**
@ -193,6 +202,12 @@ class ToolDefinition {
: `Invoking tool: ${this.#displayName || this.#name}`;
}
async shouldRegister() {
return typeof this.#shouldRegister === 'function'
? await this.#shouldRegister()
: true;
}
get displayName() {
return this.#displayName;
}
@ -228,17 +243,17 @@ export class ToolManager {
* Registers a new tool with the tool registry.
* @param {ToolRegistration} tool The tool to register.
*/
static registerFunctionTool({ name, displayName, description, parameters, action, formatMessage }) {
static registerFunctionTool({ name, displayName, description, parameters, action, formatMessage, shouldRegister }) {
// Convert WIP arguments
if (typeof arguments[0] !== 'object') {
[name, description, parameters, action] = arguments;
}
if (this.#tools.has(name)) {
console.warn(`A tool with the name "${name}" has already been registered. The definition will be overwritten.`);
console.warn(`[ToolManager] A tool with the name "${name}" has already been registered. The definition will be overwritten.`);
}
const definition = new ToolDefinition(name, displayName, description, parameters, action, formatMessage);
const definition = new ToolDefinition(name, displayName, description, parameters, action, formatMessage, shouldRegister);
this.#tools.set(name, definition);
console.log('[ToolManager] Registered function tool:', definition);
}
@ -273,7 +288,7 @@ export class ToolManager {
const result = await tool.invoke(invokeParameters);
return typeof result === 'string' ? result : JSON.stringify(result);
} catch (error) {
console.error(`An error occurred while invoking the tool "${name}":`, error);
console.error(`[ToolManager] An error occurred while invoking the tool "${name}":`, error);
if (error instanceof Error) {
error.cause = name;
@ -300,7 +315,7 @@ export class ToolManager {
const formatParameters = typeof parameters === 'string' ? JSON.parse(parameters) : parameters;
return tool.formatMessage(formatParameters);
} catch (error) {
console.error(`An error occurred while formatting the tool call message for "${name}":`, error);
console.error(`[ToolManager] An error occurred while formatting the tool call message for "${name}":`, error);
return `Invoking tool: ${name}`;
}
}
@ -327,11 +342,16 @@ export class ToolManager {
const tools = [];
for (const tool of ToolManager.tools) {
const register = await tool.shouldRegister();
if (!register) {
console.log('[ToolManager] Skipping tool registration:', tool);
continue;
}
tools.push(tool.toFunctionOpenAI());
}
if (tools.length) {
console.log('Registered function tools:', tools);
console.log('[ToolManager] Registered function tools:', tools);
data['tools'] = tools;
data['tool_choice'] = 'auto';
@ -422,7 +442,7 @@ export class ToolManager {
delete targetToolCall[this.#INPUT_DELTA_KEY];
ToolManager.#applyToolCallDelta(targetToolCall, jsonDelta);
} catch (error) {
console.warn('Failed to apply input JSON delta:', error);
console.warn('[ToolManager] Failed to apply input JSON delta:', error);
}
}
}
@ -564,7 +584,7 @@ export class ToolManager {
continue;
}
console.log('Function tool call:', toolCall);
console.log('[ToolManager] Function tool call:', toolCall);
const id = toolCall.id;
const parameters = toolCall.function.arguments;
const name = toolCall.function.name;
@ -574,7 +594,7 @@ export class ToolManager {
const toast = message && toastr.info(message, 'Tool Calling', { timeOut: 0 });
const toolResult = await ToolManager.invokeFunctionTool(name, parameters);
toastr.clear(toast);
console.log('Function tool result:', result);
console.log('[ToolManager] Function tool result:', result);
// Save a successful invocation
if (toolResult instanceof Error) {