Update tool registration

This commit is contained in:
Cohee 2024-10-04 00:39:28 +03:00
parent 6558b10675
commit 5cf64a2613

View File

@ -5,6 +5,7 @@ import { Popup } from './popup.js';
/**
* @typedef {object} ToolInvocation
* @property {string} id - A unique identifier for the tool invocation.
* @property {string} displayName - The display name of the tool.
* @property {string} name - The name of the tool.
* @property {string} parameters - The parameters for the tool invocation.
* @property {string} result - The result of the tool invocation.
@ -27,6 +28,12 @@ class ToolDefinition {
*/
#name;
/**
* A user-friendly display name for the tool.
* @type {string}
*/
#displayName;
/**
* A description of what the tool does.
* @type {string}
@ -45,18 +52,28 @@ class ToolDefinition {
*/
#action;
/**
* A function that will be called to format the tool call toast.
* @type {function}
*/
#formatMessage;
/**
* Creates a new ToolDefinition.
* @param {string} name A unique name for the tool.
* @param {string} displayName A user-friendly display name for the tool.
* @param {string} description A description of what the tool does.
* @param {object} parameters A JSON schema for the parameters that the tool accepts.
* @param {function} action A function that will be called when the tool is executed.
* @param {function} formatMessage A function that will be called to format the tool call toast.
*/
constructor(name, description, parameters, action) {
constructor(name, displayName, description, parameters, action, formatMessage) {
this.#name = name;
this.#displayName = displayName;
this.#description = description;
this.#parameters = parameters;
this.#action = action;
this.#formatMessage = formatMessage;
}
/**
@ -82,6 +99,21 @@ class ToolDefinition {
async invoke(parameters) {
return await this.#action(parameters);
}
/**
* Formats a message with the tool invocation.
* @param {object} parameters The parameters to pass to the tool.
* @returns {string} The formatted message.
*/
formatMessage(parameters) {
return typeof this.#formatMessage === 'function'
? this.#formatMessage(parameters)
: `Invoking tool: ${this.#displayName || this.#name}`;
}
get displayName() {
return this.#displayName;
}
}
/**
@ -104,17 +136,25 @@ export class ToolManager {
/**
* Registers a new tool with the tool registry.
* @param {string} name The name of the tool.
* @param {string} description A description of what the tool does.
* @param {object} parameters A JSON schema for the parameters that the tool accepts.
* @param {function} action A function that will be called when the tool is executed.
* @param {object} tool The tool to register.
* @param {string} tool.name The name of the tool.
* @param {string} tool.displayName A user-friendly display name for the tool.
* @param {string} tool.description A description of what the tool does.
* @param {object} tool.parameters A JSON schema for the parameters that the tool accepts.
* @param {function} tool.action A function that will be called when the tool is executed.
* @param {function} tool.formatMessage A function that will be called to format the tool call toast.
*/
static registerFunctionTool(name, description, parameters, action) {
static registerFunctionTool({ name, displayName, description, parameters, action, formatMessage }) {
// Convert WIP arguments
if (typeof arguments[0] !== 'object') {
[name, description, parameters, action] = arguments;
}
if (this.#tools.has(name)) {
console.warn(`A tool with the name "${name}" has already been registered. The definition will be overwritten.`);
}
const definition = new ToolDefinition(name, description, parameters, action);
const definition = new ToolDefinition(name, displayName, description, parameters, action, formatMessage);
this.#tools.set(name, definition);
console.log('[ToolManager] Registered function tool:', definition);
}
@ -161,6 +201,35 @@ export class ToolManager {
}
}
static formatToolCallMessage(name, parameters) {
if (!this.#tools.has(name)) {
return `Invoked unknown tool: ${name}`;
}
try {
const tool = this.#tools.get(name);
const formatParameters = typeof parameters === 'string' ? JSON.parse(parameters) : parameters;
return tool.formatMessage(formatParameters);
} catch (error) {
console.error(`An error occurred while formatting the tool call message for "${name}":`, error);
return `Invoking tool: ${name}`;
}
}
/**
* Gets the display name of a tool by name.
* @param {string} name
* @returns {string} The display name of the tool.
*/
static getDisplayName(name) {
if (!this.#tools.has(name)) {
return name;
}
const tool = this.#tools.get(name);
return tool.displayName || name;
}
/**
* Register function tools for the next chat completion request.
* @param {object} data Generation data
@ -352,9 +421,11 @@ export class ToolManager {
const id = toolCall.id;
const parameters = toolCall.function.arguments;
const name = toolCall.function.name;
const displayName = ToolManager.getDisplayName(name);
result.hadToolCalls = true;
const toast = toastr.info(`Invoking function tool: ${name}`);
const message = ToolManager.formatToolCallMessage(name, parameters);
const toast = message && toastr.info(message, 'Tool Calling', { timeOut: 0 });
const toolResult = await ToolManager.invokeFunctionTool(name, parameters);
toastr.clear(toast);
console.log('Function tool result:', result);
@ -365,7 +436,14 @@ export class ToolManager {
continue;
}
result.invocations.push({ id, name, parameters, result: toolResult });
const invocation = {
id,
displayName,
name,
parameters,
result: toolResult,
};
result.invocations.push(invocation);
}
}
@ -414,8 +492,8 @@ export class ToolManager {
codeElement.classList.add('language-json');
data.forEach(i => i.parameters = tryParse(i.parameters));
codeElement.textContent = JSON.stringify(data, null, 2);
const toolNames = data.map(i => i.name).join(', ');
summaryElement.textContent = `Performed tool calls: ${toolNames}`;
const toolNames = data.map(i => i.displayName || i.name).join(', ');
summaryElement.textContent = `Tool calls: ${toolNames}`;
preElement.append(codeElement);
detailsElement.append(summaryElement, preElement);
return detailsElement.outerHTML;