From 8006795897c54180823323357ad5a74c8660fe25 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 2 Oct 2024 01:00:48 +0300 Subject: [PATCH 01/50] New tool calling framework --- public/global.d.ts | 41 -- public/script.js | 35 +- .../scripts/extensions/expressions/index.js | 52 +-- public/scripts/openai.js | 149 +------ public/scripts/tool-calling.js | 381 ++++++++++++++++++ src/endpoints/backends/chat-completions.js | 22 +- 6 files changed, 427 insertions(+), 253 deletions(-) create mode 100644 public/scripts/tool-calling.js diff --git a/public/global.d.ts b/public/global.d.ts index c8bfe14c2..39ba4ea0c 100644 --- a/public/global.d.ts +++ b/public/global.d.ts @@ -1365,44 +1365,3 @@ declare namespace moment { declare global { const moment: typeof moment; } - -/** - * Callback data for the `LLM_FUNCTION_TOOL_REGISTER` event type that is triggered when a function tool can be registered. - */ -interface FunctionToolRegister { - /** - * The type of generation that is being used - */ - type?: string; - /** - * Generation data, including messages and sampling parameters - */ - data: Record; - /** - * Callback to register an LLM function tool. - */ - registerFunctionTool: typeof registerFunctionTool; -} - -/** - * Callback data for the `LLM_FUNCTION_TOOL_REGISTER` event type that is triggered when a function tool is registered. - * @param name Name of the function tool to register - * @param description Description of the function tool - * @param params JSON schema for the parameters of the function tool - * @param required Whether the function tool should be forced to be used - */ -declare function registerFunctionTool(name: string, description: string, params: object, required: boolean): Promise; - -/** - * Callback data for the `LLM_FUNCTION_TOOL_CALL` event type that is triggered when a function tool is called. - */ -interface FunctionToolCall { - /** - * Name of the function tool to call - */ - name: string; - /** - * JSON object with the parameters to pass to the function tool - */ - arguments: string; -} diff --git a/public/script.js b/public/script.js index 35bfcbab0..fac11b8f4 100644 --- a/public/script.js +++ b/public/script.js @@ -246,6 +246,7 @@ import { initInputMarkdown } from './scripts/input-md-formatting.js'; import { AbortReason } from './scripts/util/AbortReason.js'; import { initSystemPrompts } from './scripts/sysprompt.js'; import { registerExtensionSlashCommands as initExtensionSlashCommands } from './scripts/extensions-slashcommands.js'; +import { ToolManager } from './scripts/tool-calling.js'; //exporting functions and vars for mods export { @@ -463,8 +464,6 @@ export const event_types = { FILE_ATTACHMENT_DELETED: 'file_attachment_deleted', WORLDINFO_FORCE_ACTIVATE: 'worldinfo_force_activate', OPEN_CHARACTER_LIBRARY: 'open_character_library', - LLM_FUNCTION_TOOL_REGISTER: 'llm_function_tool_register', - LLM_FUNCTION_TOOL_CALL: 'llm_function_tool_call', ONLINE_STATUS_CHANGED: 'online_status_changed', IMAGE_SWIPED: 'image_swiped', CONNECTION_PROFILE_LOADED: 'connection_profile_loaded', @@ -2921,6 +2920,7 @@ class StreamingProcessor { this.swipes = []; /** @type {import('./scripts/logprobs.js').TokenLogprobs[]} */ this.messageLogprobs = []; + this.toolCalls = []; } #checkDomElements(messageId) { @@ -3139,7 +3139,7 @@ class StreamingProcessor { } /** - * @returns {Generator<{ text: string, swipes: string[], logprobs: import('./scripts/logprobs.js').TokenLogprobs }, void, void>} + * @returns {Generator<{ text: string, swipes: string[], logprobs: import('./scripts/logprobs.js').TokenLogprobs, toolCalls: any[] }, void, void>} */ *nullStreamingGeneration() { throw new Error('Generation function for streaming is not hooked up'); @@ -3161,12 +3161,13 @@ class StreamingProcessor { try { const sw = new Stopwatch(1000 / power_user.streaming_fps); const timestamps = []; - for await (const { text, swipes, logprobs } of this.generator()) { + for await (const { text, swipes, logprobs, toolCalls } of this.generator()) { timestamps.push(Date.now()); if (this.isStopped) { return; } + this.toolCalls = toolCalls; this.result = text; this.swipes = Array.from(swipes ?? []); if (logprobs) { @@ -4405,6 +4406,20 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro getMessage = continue_mag + getMessage; } + if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { + const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls); + if (invocations.length) { + const lastMessage = chat[chat.length - 1]; + const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); + if (shouldDeleteMessage) { + await deleteLastMessage(); + streamingProcessor = null; + } + ToolManager.saveFunctionToolInvocations(invocations); + return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); + } + } + if (streamingProcessor && !streamingProcessor.isStopped && streamingProcessor.isFinished) { await streamingProcessor.onFinishStreaming(streamingProcessor.messageId, getMessage); streamingProcessor = null; @@ -4440,6 +4455,14 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro throw new Error(data?.response); } + if (ToolManager.isFunctionCallingSupported()) { + const invocations = await ToolManager.checkFunctionToolCalls(data); + if (invocations.length) { + ToolManager.saveFunctionToolInvocations(invocations); + return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); + } + } + //const getData = await response.json(); let getMessage = extractMessageFromData(data); let title = extractTitleFromData(data); @@ -7853,7 +7876,7 @@ function openAlternateGreetings() { if (menu_type !== 'create') { await createOrEditCharacter(); } - } + }, }); for (let index = 0; index < getArray().length; index++) { @@ -8130,6 +8153,8 @@ window['SillyTavern'].getContext = function () { registerHelper: () => { }, registerMacro: MacrosParser.registerMacro.bind(MacrosParser), unregisterMacro: MacrosParser.unregisterMacro.bind(MacrosParser), + registerFunctionTool: ToolManager.registerFunctionTool.bind(ToolManager), + unregisterFunctionTool: ToolManager.unregisterFunctionTool.bind(ToolManager), registerDebugFunction: registerDebugFunction, /** @deprecated Use renderExtensionTemplateAsync instead. */ renderExtensionTemplate: renderExtensionTemplate, diff --git a/public/scripts/extensions/expressions/index.js b/public/scripts/extensions/expressions/index.js index 4a09e3aeb..078375cf1 100644 --- a/public/scripts/extensions/expressions/index.js +++ b/public/scripts/extensions/expressions/index.js @@ -9,7 +9,6 @@ import { debounce_timeout } from '../../constants.js'; import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js'; import { SlashCommand } from '../../slash-commands/SlashCommand.js'; import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js'; -import { isFunctionCallingSupported } from '../../openai.js'; import { SlashCommandEnumValue, enumTypes } from '../../slash-commands/SlashCommandEnumValue.js'; import { commonEnumProviders } from '../../slash-commands/SlashCommandCommonEnumsProvider.js'; import { slashCommandReturnHelper } from '../../slash-commands/SlashCommandReturnHelper.js'; @@ -21,7 +20,6 @@ const UPDATE_INTERVAL = 2000; const STREAMING_UPDATE_INTERVAL = 10000; const TALKINGCHECK_UPDATE_INTERVAL = 500; const DEFAULT_FALLBACK_EXPRESSION = 'joy'; -const FUNCTION_NAME = 'set_emotion'; const DEFAULT_LLM_PROMPT = 'Ignore previous instructions. Classify the emotion of the last message. Output just one word, e.g. "joy" or "anger". Choose only one of the following labels: {{labels}}'; const DEFAULT_EXPRESSIONS = [ 'talkinghead', @@ -1017,10 +1015,6 @@ async function getLlmPrompt(labels) { return ''; } - if (isFunctionCallingSupported()) { - return ''; - } - const labelsString = labels.map(x => `"${x}"`).join(', '); const prompt = substituteParamsExtended(String(extension_settings.expressions.llmPrompt), { labels: labelsString }); return prompt; @@ -1056,41 +1050,6 @@ function parseLlmResponse(emotionResponse, labels) { throw new Error('Could not parse emotion response ' + emotionResponse); } -/** - * Registers the function tool for the LLM API. - * @param {FunctionToolRegister} args Function tool register arguments. - */ -function onFunctionToolRegister(args) { - if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isFunctionCallingSupported()) { - // Only trigger on quiet mode - if (args.type !== 'quiet') { - return; - } - - const emotions = DEFAULT_EXPRESSIONS.filter((e) => e != 'talkinghead'); - const jsonSchema = { - $schema: 'http://json-schema.org/draft-04/schema#', - type: 'object', - properties: { - emotion: { - type: 'string', - enum: emotions, - description: `One of the following: ${JSON.stringify(emotions)}`, - }, - }, - required: [ - 'emotion', - ], - }; - args.registerFunctionTool( - FUNCTION_NAME, - substituteParams('Sets the label that best describes the current emotional state of {{char}}. Only select one of the enumerated values.'), - jsonSchema, - true, - ); - } -} - function onTextGenSettingsReady(args) { // Only call if inside an API call if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isJsonSchemaSupported()) { @@ -1164,18 +1123,9 @@ export async function getExpressionLabel(text, expressionsApi = extension_settin const expressionsList = await getExpressionsList(); const prompt = substituteParamsExtended(customPrompt, { labels: expressionsList }) || await getLlmPrompt(expressionsList); - let functionResult = null; eventSource.once(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady); - eventSource.once(event_types.LLM_FUNCTION_TOOL_REGISTER, onFunctionToolRegister); - eventSource.once(event_types.LLM_FUNCTION_TOOL_CALL, (/** @type {FunctionToolCall} */ args) => { - if (args.name !== FUNCTION_NAME) { - return; - } - - functionResult = args?.arguments; - }); const emotionResponse = await generateRaw(text, main_api, false, false, prompt); - return parseLlmResponse(functionResult || emotionResponse, expressionsList); + return parseLlmResponse(emotionResponse, expressionsList); } // Extras default: { diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 20c141316..a6048dce5 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -70,6 +70,7 @@ import { renderTemplateAsync } from './templates.js'; import { SlashCommandEnumValue } from './slash-commands/SlashCommandEnumValue.js'; import { Popup, POPUP_RESULT } from './popup.js'; import { t } from './i18n.js'; +import { ToolManager } from './tool-calling.js'; export { openai_messages_count, @@ -1863,8 +1864,8 @@ async function sendOpenAIRequest(type, messages, signal) { generate_data['seed'] = oai_settings.seed; } - if (isFunctionCallingSupported() && !stream) { - await registerFunctionTools(type, generate_data); + if (!canMultiSwipe && ToolManager.isFunctionCallingSupported()) { + await ToolManager.registerFunctionToolsOpenAI(generate_data); } if (isOAI && oai_settings.openai_model.startsWith('o1-')) { @@ -1911,6 +1912,7 @@ async function sendOpenAIRequest(type, messages, signal) { return async function* streamData() { let text = ''; const swipes = []; + const toolCalls = []; while (true) { const { done, value } = await reader.read(); if (done) return; @@ -1926,7 +1928,9 @@ async function sendOpenAIRequest(type, messages, signal) { text += getStreamingReply(parsed); } - yield { text, swipes: swipes, logprobs: parseChatCompletionLogprobs(parsed) }; + ToolManager.parseToolCalls(toolCalls, parsed); + + yield { text, swipes: swipes, logprobs: parseChatCompletionLogprobs(parsed), toolCalls: toolCalls }; } }; } @@ -1948,147 +1952,10 @@ async function sendOpenAIRequest(type, messages, signal) { delay(1).then(() => saveLogprobsForActiveMessage(logprobs, null)); } - if (isFunctionCallingSupported()) { - await checkFunctionToolCalls(data); - } - return data; } } -/** - * Register function tools for the next chat completion request. - * @param {string} type Generation type - * @param {object} data Generation data - */ -async function registerFunctionTools(type, data) { - let toolChoice = 'auto'; - const tools = []; - - /** - * @type {registerFunctionTool} - */ - const registerFunctionTool = (name, description, parameters, required) => { - tools.push({ - type: 'function', - function: { - name, - description, - parameters, - }, - }); - - if (required) { - toolChoice = 'required'; - } - }; - - /** - * @type {FunctionToolRegister} - */ - const args = { - type, - data, - registerFunctionTool, - }; - - await eventSource.emit(event_types.LLM_FUNCTION_TOOL_REGISTER, args); - - if (tools.length) { - console.log('Registered function tools:', tools); - - data['tools'] = tools; - data['tool_choice'] = toolChoice; - } -} - -async function checkFunctionToolCalls(data) { - const oaiCompat = [ - chat_completion_sources.OPENAI, - chat_completion_sources.CUSTOM, - chat_completion_sources.MISTRALAI, - chat_completion_sources.OPENROUTER, - chat_completion_sources.GROQ, - ]; - if (oaiCompat.includes(oai_settings.chat_completion_source)) { - if (!Array.isArray(data?.choices)) { - return; - } - - // Find a choice with 0-index - const choice = data.choices.find(choice => choice.index === 0); - - if (!choice) { - return; - } - - const toolCalls = choice.message.tool_calls; - - if (!Array.isArray(toolCalls)) { - return; - } - - for (const toolCall of toolCalls) { - if (typeof toolCall.function !== 'object') { - continue; - } - - /** @type {FunctionToolCall} */ - const args = toolCall.function; - console.log('Function tool call:', toolCall); - await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args); - } - } - - if ([chat_completion_sources.CLAUDE].includes(oai_settings.chat_completion_source)) { - if (!Array.isArray(data?.content)) { - return; - } - - for (const content of data.content) { - if (content.type === 'tool_use') { - /** @type {FunctionToolCall} */ - const args = { name: content.name, arguments: JSON.stringify(content.input) }; - await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args); - } - } - } - - if ([chat_completion_sources.COHERE].includes(oai_settings.chat_completion_source)) { - if (!Array.isArray(data?.tool_calls)) { - return; - } - - for (const toolCall of data.tool_calls) { - /** @type {FunctionToolCall} */ - const args = { name: toolCall.name, arguments: JSON.stringify(toolCall.parameters) }; - console.log('Function tool call:', toolCall); - await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args); - } - } -} - -export function isFunctionCallingSupported() { - if (main_api !== 'openai') { - return false; - } - - if (!oai_settings.function_calling) { - return false; - } - - const supportedSources = [ - chat_completion_sources.OPENAI, - chat_completion_sources.COHERE, - chat_completion_sources.CUSTOM, - chat_completion_sources.MISTRALAI, - chat_completion_sources.CLAUDE, - chat_completion_sources.OPENROUTER, - chat_completion_sources.GROQ, - ]; - return supportedSources.includes(oai_settings.chat_completion_source); -} - function getStreamingReply(data) { if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) { return data?.delta?.text || ''; @@ -4019,7 +3886,7 @@ async function onModelChange() { $('#openai_max_context').attr('max', max_32k); } else if (value === 'text-bison-001') { $('#openai_max_context').attr('max', max_8k); - // The ultra endpoints are possibly dead: + // The ultra endpoints are possibly dead: } else if (value.includes('gemini-1.0-ultra') || value === 'gemini-ultra') { $('#openai_max_context').attr('max', max_32k); } else { diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js new file mode 100644 index 000000000..c2f6c209f --- /dev/null +++ b/public/scripts/tool-calling.js @@ -0,0 +1,381 @@ +import { chat, main_api } from '../script.js'; +import { chat_completion_sources, oai_settings } from './openai.js'; + +/** + * @typedef {object} ToolInvocation + * @property {string} id - A unique identifier for the tool invocation. + * @property {string} name - The name of the tool. + * @property {string} parameters - The parameters for the tool invocation. + * @property {string} result - The result of the tool invocation. + */ + +/** + * A class that represents a tool definition. + */ +class ToolDefinition { + /** + * A unique name for the tool. + * @type {string} + */ + #name; + + /** + * A description of what the tool does. + * @type {string} + */ + #description; + + /** + * A JSON schema for the parameters that the tool accepts. + * @type {object} + */ + #parameters; + + /** + * A function that will be called when the tool is executed. + * @type {function} + */ + #action; + + /** + * Creates a new ToolDefinition. + * @param {string} name A unique name for the tool. + * @param {string} description A description of what the tool does. + * @param {object} parameters A JSON schema for the parameters that the tool accepts. + * @param {function} action A function that will be called when the tool is executed. + */ + constructor(name, description, parameters, action) { + this.#name = name; + this.#description = description; + this.#parameters = parameters; + this.#action = action; + } + + /** + * Converts the ToolDefinition to an OpenAI API representation + * @returns {object} OpenAI API representation of the tool. + */ + toFunctionOpenAI() { + return { + type: 'function', + function: { + name: this.#name, + description: this.#description, + parameters: this.#parameters, + }, + }; + } + + /** + * Invokes the tool with the given parameters. + * @param {object} parameters The parameters to pass to the tool. + * @returns {Promise} The result of the tool's action function. + */ + async invoke(parameters) { + return await this.#action(parameters); + } +} + +/** + * A class that manages the registration and invocation of tools. + */ +export class ToolManager { + /** + * A map of tool names to tool definitions. + * @type {Map} + */ + static #tools = new Map(); + + /** + * Returns an Array of all tools that have been registered. + * @type {ToolDefinition[]} + */ + static get tools() { + return Array.from(this.#tools.values()); + } + + /** + * Registers a new tool with the tool registry. + * @param {string} name The name of the tool. + * @param {string} description A description of what the tool does. + * @param {object} parameters A JSON schema for the parameters that the tool accepts. + * @param {function} action A function that will be called when the tool is executed. + */ + static registerFunctionTool(name, description, parameters, action) { + if (this.#tools.has(name)) { + console.warn(`A tool with the name "${name}" has already been registered. The definition will be overwritten.`); + } + + const definition = new ToolDefinition(name, description, parameters, action); + this.#tools.set(name, definition); + } + + /** + * Removes a tool from the tool registry. + * @param {string} name The name of the tool to unregister. + */ + static unregisterFunctionTool(name) { + if (!this.#tools.has(name)) { + console.warn(`No tool with the name "${name}" has been registered.`); + return; + } + + this.#tools.delete(name); + } + + /** + * Invokes a tool by name. Returns the result of the tool's action function. + * @param {string} name The name of the tool to invoke. + * @param {object} parameters Function parameters. For example, if the tool requires a "name" parameter, you would pass {name: "value"}. + * @returns {Promise} The result of the tool's action function. If an error occurs, null is returned. Non-string results are JSON-stringified. + */ + static async invokeFunctionTool(name, parameters) { + try { + if (!this.#tools.has(name)) { + throw new Error(`No tool with the name "${name}" has been registered.`); + } + + const invokeParameters = typeof parameters === 'string' ? JSON.parse(parameters) : parameters; + const tool = this.#tools.get(name); + const result = await tool.invoke(invokeParameters); + return typeof result === 'string' ? result : JSON.stringify(result); + } catch (error) { + console.error(`An error occurred while invoking the tool "${name}":`, error); + return null; + } + } + + /** + * Register function tools for the next chat completion request. + * @param {object} data Generation data + */ + static async registerFunctionToolsOpenAI(data) { + const tools = []; + + for (const tool of ToolManager.tools) { + tools.push(tool.toFunctionOpenAI()); + } + + if (tools.length) { + console.log('Registered function tools:', tools); + + data['tools'] = tools; + data['tool_choice'] = 'auto'; + } + } + + /** + * Utility function to parse tool calls from a parsed response. + * @param {any[]} toolCalls The tool calls to update. + * @param {any} parsed The parsed response from the OpenAI API. + * @returns {void} + */ + static parseToolCalls(toolCalls, parsed) { + if (!Array.isArray(parsed?.choices)) { + return; + } + for (const choice of parsed.choices) { + const choiceIndex = (typeof choice.index === 'number') ? choice.index : null; + const choiceDelta = choice.delta; + + if (choiceIndex === null || !choiceDelta) { + continue; + } + + const toolCallDeltas = choiceDelta?.tool_calls; + + if (!Array.isArray(toolCallDeltas)) { + continue; + } + + if (!Array.isArray(toolCalls[choiceIndex])) { + toolCalls[choiceIndex] = []; + } + + for (const toolCallDelta of toolCallDeltas) { + const toolCallIndex = (typeof toolCallDelta?.index === 'number') ? toolCallDelta.index : null; + + if (toolCallIndex === null) { + continue; + } + + if (toolCalls[choiceIndex][toolCallIndex] === undefined) { + toolCalls[choiceIndex][toolCallIndex] = {}; + } + + const targetToolCall = toolCalls[choiceIndex][toolCallIndex]; + + ToolManager.#applyToolCallDelta(targetToolCall, toolCallDelta); + } + } + } + + static #applyToolCallDelta(target, delta) { + for (const key in delta) { + if (!delta.hasOwnProperty(key)) continue; + + const deltaValue = delta[key]; + const targetValue = target[key]; + + if (deltaValue === null || deltaValue === undefined) { + target[key] = deltaValue; + continue; + } + + if (typeof deltaValue === 'string') { + if (typeof targetValue === 'string') { + // Concatenate strings + target[key] = targetValue + deltaValue; + } else { + target[key] = deltaValue; + } + } else if (typeof deltaValue === 'object' && !Array.isArray(deltaValue)) { + if (typeof targetValue !== 'object' || targetValue === null || Array.isArray(targetValue)) { + target[key] = {}; + } + // Recursively apply deltas to nested objects + ToolManager.#applyToolCallDelta(target[key], deltaValue); + } else { + // Assign other types directly + target[key] = deltaValue; + } + } + } + + static isFunctionCallingSupported() { + if (main_api !== 'openai') { + return false; + } + + if (!oai_settings.function_calling) { + return false; + } + + const supportedSources = [ + chat_completion_sources.OPENAI, + //chat_completion_sources.COHERE, + chat_completion_sources.CUSTOM, + chat_completion_sources.MISTRALAI, + //chat_completion_sources.CLAUDE, + chat_completion_sources.OPENROUTER, + chat_completion_sources.GROQ, + ]; + return supportedSources.includes(oai_settings.chat_completion_source); + } + + static #getToolCallsFromData(data) { + // Parsed tool calls from streaming data + if (Array.isArray(data) && data.length > 0) { + return data[0]; + } + + // Parsed tool calls from non-streaming data + if (!Array.isArray(data?.choices)) { + return; + } + + // Find a choice with 0-index + const choice = data.choices.find(choice => choice.index === 0); + + if (!choice) { + return; + } + + return choice.message.tool_calls; + } + + /** + * Check for function tool calls in the response data and invoke them. + * @param {any} data Reply data + * @returns {Promise} Successful tool invocations + */ + static async checkFunctionToolCalls(data) { + if (!ToolManager.isFunctionCallingSupported()) { + return []; + } + + /** @type {ToolInvocation[]} */ + const invocations = []; + const toolCalls = ToolManager.#getToolCallsFromData(data); + const oaiCompat = [ + chat_completion_sources.OPENAI, + chat_completion_sources.CUSTOM, + chat_completion_sources.MISTRALAI, + chat_completion_sources.OPENROUTER, + chat_completion_sources.GROQ, + ]; + + if (oaiCompat.includes(oai_settings.chat_completion_source)) { + if (!Array.isArray(toolCalls)) { + return; + } + + for (const toolCall of toolCalls) { + if (typeof toolCall.function !== 'object') { + continue; + } + + console.log('Function tool call:', toolCall); + const id = toolCall.id; + const parameters = toolCall.function.arguments; + const name = toolCall.function.name; + + toastr.info('Invoking function tool: ' + name); + const result = await ToolManager.invokeFunctionTool(name, parameters); + toastr.info('Function tool result: ' + result); + + // Save a successful invocation + if (result) { + invocations.push({ id, name, result, parameters }); + } + } + } + + /* + if ([chat_completion_sources.CLAUDE].includes(oai_settings.chat_completion_source)) { + if (!Array.isArray(data?.content)) { + return; + } + + for (const content of data.content) { + if (content.type === 'tool_use') { + const args = { name: content.name, arguments: JSON.stringify(content.input) }; + } + } + } + */ + + /* + if ([chat_completion_sources.COHERE].includes(oai_settings.chat_completion_source)) { + if (!Array.isArray(data?.tool_calls)) { + return; + } + + for (const toolCall of data.tool_calls) { + const args = { name: toolCall.name, arguments: JSON.stringify(toolCall.parameters) }; + console.log('Function tool call:', toolCall); + } + } + */ + + return invocations; + } + + /** + * Saves function tool invocations to the last user chat message extra metadata. + * @param {ToolInvocation[]} invocations + */ + static saveFunctionToolInvocations(invocations) { + for (let index = chat.length - 1; index >= 0; index--) { + const message = chat[index]; + if (message.is_user) { + if (!message.extra || typeof message.extra !== 'object') { + message.extra = {}; + } + message.extra.tool_invocations = invocations; + debugger; + break; + } + } + } +} diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index ca86db051..095a65ceb 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -121,18 +121,20 @@ async function sendClaudeRequest(request, response) { ? [{ type: 'text', text: convertedPrompt.systemPrompt, cache_control: { type: 'ephemeral' } }] : convertedPrompt.systemPrompt; } + /* if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { // Claude doesn't do prefills on function calls, and doesn't allow empty messages if (convertedPrompt.messages.length && convertedPrompt.messages[convertedPrompt.messages.length - 1].role === 'assistant') { convertedPrompt.messages.push({ role: 'user', content: '.' }); } additionalHeaders['anthropic-beta'] = 'tools-2024-05-16'; - requestBody.tool_choice = { type: request.body.tool_choice === 'required' ? 'any' : 'auto' }; + requestBody.tool_choice = { type: request.body.tool_choice }; requestBody.tools = request.body.tools .filter(tool => tool.type === 'function') .map(tool => tool.function) .map(fn => ({ name: fn.name, description: fn.description, input_schema: fn.parameters })); } + */ if (enableSystemPromptCache) { additionalHeaders['anthropic-beta'] = 'prompt-caching-2024-07-31'; } @@ -479,7 +481,7 @@ async function sendMistralAIRequest(request, response) { if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { requestBody['tools'] = request.body.tools; - requestBody['tool_choice'] = request.body.tool_choice === 'required' ? 'any' : 'auto'; + requestBody['tool_choice'] = request.body.tool_choice; } const config = { @@ -549,11 +551,13 @@ async function sendCohereRequest(request, response) { }); } + /* if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { tools.push(...convertCohereTools(request.body.tools)); // Can't have both connectors and tools in the same request connectors.splice(0, connectors.length); } + */ // https://docs.cohere.com/reference/chat const requestBody = { @@ -910,18 +914,6 @@ router.post('/generate', jsonParser, function (request, response) { apiKey = readSecret(request.user.directories, SECRET_KEYS.GROQ); headers = {}; bodyParams = {}; - - // 'required' tool choice is not supported by Groq - if (request.body.tool_choice === 'required') { - if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { - request.body.tool_choice = request.body.tools.length > 1 - ? 'auto' : - { type: 'function', function: { name: request.body.tools[0]?.function?.name } }; - - } else { - request.body.tool_choice = 'none'; - } - } } else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.ZEROONEAI) { apiUrl = API_01AI; apiKey = readSecret(request.user.directories, SECRET_KEYS.ZEROONEAI); @@ -958,7 +950,7 @@ router.post('/generate', jsonParser, function (request, response) { controller.abort(); }); - if (!isTextCompletion) { + if (!isTextCompletion && Array.isArray(request.body.tools) && request.body.tools.length > 0) { bodyParams['tools'] = request.body.tools; bodyParams['tool_choice'] = request.body.tool_choice; } From c94c06ed4dac3c4895e4afb5b0e6c1524829e49b Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 2 Oct 2024 01:45:57 +0300 Subject: [PATCH 02/50] Implement function tool calling for OpenAI --- public/script.js | 4 +-- public/scripts/openai.js | 52 +++++++++++++++++++++++++++++++--- public/scripts/tool-calling.js | 5 ++-- 3 files changed, 52 insertions(+), 9 deletions(-) diff --git a/public/script.js b/public/script.js index fac11b8f4..3312c556f 100644 --- a/public/script.js +++ b/public/script.js @@ -4408,7 +4408,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls); - if (invocations.length) { + if (Array.isArray(invocations) && invocations.length) { const lastMessage = chat[chat.length - 1]; const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); if (shouldDeleteMessage) { @@ -4457,7 +4457,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro if (ToolManager.isFunctionCallingSupported()) { const invocations = await ToolManager.checkFunctionToolCalls(data); - if (invocations.length) { + if (Array.isArray(invocations) && invocations.length) { ToolManager.saveFunctionToolInvocations(invocations); return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); } diff --git a/public/scripts/openai.js b/public/scripts/openai.js index a6048dce5..8b9c06048 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -454,7 +454,8 @@ function setOpenAIMessages(chat) { if (role == 'user' && oai_settings.wrap_in_quotes) content = `"${content}"`; const name = chat[j]['name']; const image = chat[j]?.extra?.image; - messages[i] = { 'role': role, 'content': content, name: name, 'image': image }; + const invocations = chat[j]?.extra?.tool_invocations; + messages[i] = { 'role': role, 'content': content, name: name, 'image': image, 'invocations': invocations }; j++; } @@ -702,6 +703,7 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul } const imageInlining = isImageInliningSupported(); + const toolCalling = ToolManager.isFunctionCallingSupported(); // Insert chat messages as long as there is budget available const chatPool = [...messages].reverse(); @@ -723,6 +725,24 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul await chatMessage.addImage(chatPrompt.image); } + if (toolCalling && Array.isArray(chatPrompt.invocations)) { + /** @type {import('./tool-calling.js').ToolInvocation[]} */ + const invocations = chatPrompt.invocations.slice().reverse(); + const toolCallMessage = new Message('assistant', undefined, 'toolCall-' + chatMessage.identifier); + toolCallMessage.setToolCalls(invocations); + if (chatCompletion.canAfford(toolCallMessage)) { + for (const invocation of invocations) { + const toolResultMessage = new Message('tool', invocation.result, invocation.id); + const canAfford = chatCompletion.canAfford(toolResultMessage) && chatCompletion.canAfford(toolCallMessage); + if (!canAfford) { + break; + } + chatCompletion.insertAtStart(toolResultMessage, 'chatHistory'); + } + chatCompletion.insertAtStart(toolCallMessage, 'chatHistory'); + } + } + if (chatCompletion.canAfford(chatMessage)) { if (type === 'continue' && oai_settings.continue_prefill && chatPrompt === firstNonInjected) { // in case we are using continue_prefill and the latest message is an assistant message, we want to prepend the users assistant prefill on the message @@ -2193,6 +2213,8 @@ class Message { content; /** @type {string} */ name; + /** @type {object} */ + tool_call = null; /** * @constructor @@ -2217,6 +2239,22 @@ class Message { } } + /** + * Reconstruct the message from a tool invocation. + * @param {import('./tool-calling.js').ToolInvocation[]} invocations + */ + setToolCalls(invocations) { + this.tool_calls = invocations.map(i => ({ + id: i.id, + type: 'function', + function: { + arguments: i.parameters, + name: i.name, + }, + })); + this.tokens = tokenHandler.count({ role: this.role, tool_calls: JSON.stringify(this.tool_calls) }); + } + setName(name) { this.name = name; this.tokens = tokenHandler.count({ role: this.role, content: this.content, name: this.name }); @@ -2564,7 +2602,7 @@ export class ChatCompletion { this.checkTokenBudget(message, message.identifier); const index = this.findMessageIndex(identifier); - if (message.content) { + if (message.content || message.tool_calls) { if ('start' === position) this.messages.collection[index].collection.unshift(message); else if ('end' === position) this.messages.collection[index].collection.push(message); else if (typeof position === 'number') this.messages.collection[index].collection.splice(position, 0, message); @@ -2633,8 +2671,14 @@ export class ChatCompletion { for (let item of this.messages.collection) { if (item instanceof MessageCollection) { chat.push(...item.getChat()); - } else if (item instanceof Message && item.content) { - const message = { role: item.role, content: item.content, ...(item.name ? { name: item.name } : {}) }; + } else if (item instanceof Message && (item.content || item.tool_calls)) { + const message = { + role: item.role, + content: item.content, + ...(item.name ? { name: item.name } : {}), + ...(item.tool_calls ? { tool_calls: item.tool_calls } : {}), + ...(item.role === 'tool' ? { tool_call_id: item.identifier } : {}), + }; chat.push(message); } else { this.log(`Skipping invalid or empty message in collection: ${JSON.stringify(item)}`); diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index c2f6c209f..c88528dd7 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -307,7 +307,7 @@ export class ToolManager { if (oaiCompat.includes(oai_settings.chat_completion_source)) { if (!Array.isArray(toolCalls)) { - return; + return []; } for (const toolCall of toolCalls) { @@ -363,7 +363,7 @@ export class ToolManager { /** * Saves function tool invocations to the last user chat message extra metadata. - * @param {ToolInvocation[]} invocations + * @param {ToolInvocation[]} invocations Successful tool invocations */ static saveFunctionToolInvocations(invocations) { for (let index = chat.length - 1; index >= 0; index--) { @@ -373,7 +373,6 @@ export class ToolManager { message.extra = {}; } message.extra.tool_invocations = invocations; - debugger; break; } } From 68c87f7e7aca14e382d8d7976154629a3f007908 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 2 Oct 2024 01:53:03 +0300 Subject: [PATCH 03/50] Fix code scanning alert no. 231: Prototype-polluting function Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- public/scripts/tool-calling.js | 1 + 1 file changed, 1 insertion(+) diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index c88528dd7..c19d1ac7a 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -213,6 +213,7 @@ export class ToolManager { static #applyToolCallDelta(target, delta) { for (const key in delta) { if (!delta.hasOwnProperty(key)) continue; + if (key === "__proto__" || key === "constructor") continue; const deltaValue = delta[key]; const targetValue = target[key]; From 63724a2b383aaa2bcaa5ca4009f65062f45c39cd Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 2 Oct 2024 01:54:47 +0300 Subject: [PATCH 04/50] eslint update --- public/scripts/tool-calling.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index c19d1ac7a..964093bc3 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -213,7 +213,7 @@ export class ToolManager { static #applyToolCallDelta(target, delta) { for (const key in delta) { if (!delta.hasOwnProperty(key)) continue; - if (key === "__proto__" || key === "constructor") continue; + if (key === '__proto__' || key === 'constructor') continue; const deltaValue = delta[key]; const targetValue = target[key]; From e8b972042553b27fbce73cb9377067ebaeebc965 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 2 Oct 2024 01:56:27 +0300 Subject: [PATCH 05/50] Budgeting fix --- public/scripts/openai.js | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 8b9c06048..96821eda4 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -727,18 +727,20 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul if (toolCalling && Array.isArray(chatPrompt.invocations)) { /** @type {import('./tool-calling.js').ToolInvocation[]} */ - const invocations = chatPrompt.invocations.slice().reverse(); + const invocations = chatPrompt.invocations; const toolCallMessage = new Message('assistant', undefined, 'toolCall-' + chatMessage.identifier); toolCallMessage.setToolCalls(invocations); if (chatCompletion.canAfford(toolCallMessage)) { - for (const invocation of invocations) { + chatCompletion.reserveBudget(toolCallMessage); + for (const invocation of invocations.slice().reverse()) { const toolResultMessage = new Message('tool', invocation.result, invocation.id); - const canAfford = chatCompletion.canAfford(toolResultMessage) && chatCompletion.canAfford(toolCallMessage); + const canAfford = chatCompletion.canAfford(toolResultMessage); if (!canAfford) { break; } chatCompletion.insertAtStart(toolResultMessage, 'chatHistory'); } + chatCompletion.freeBudget(toolCallMessage); chatCompletion.insertAtStart(toolCallMessage, 'chatHistory'); } } From 3335dbf1a75dd0ffeaede2239c64fb5b85425f31 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 2 Oct 2024 01:59:53 +0300 Subject: [PATCH 06/50] Add empty tool calls to streaming processors --- public/scripts/kai-settings.js | 2 +- public/scripts/nai-settings.js | 2 +- public/scripts/textgen-settings.js | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/public/scripts/kai-settings.js b/public/scripts/kai-settings.js index 27a204c42..6efadce87 100644 --- a/public/scripts/kai-settings.js +++ b/public/scripts/kai-settings.js @@ -188,7 +188,7 @@ export async function generateKoboldWithStreaming(generate_data, signal) { if (data?.token) { text += data.token; } - yield { text, swipes: [] }; + yield { text, swipes: [], toolCalls: [] }; } }; } diff --git a/public/scripts/nai-settings.js b/public/scripts/nai-settings.js index fc4a042c5..040c6f203 100644 --- a/public/scripts/nai-settings.js +++ b/public/scripts/nai-settings.js @@ -746,7 +746,7 @@ export async function generateNovelWithStreaming(generate_data, signal) { text += data.token; } - yield { text, swipes: [], logprobs: parseNovelAILogprobs(data.logprobs) }; + yield { text, swipes: [], logprobs: parseNovelAILogprobs(data.logprobs), toolCalls: [] }; } }; } diff --git a/public/scripts/textgen-settings.js b/public/scripts/textgen-settings.js index 9403555bd..c021c9682 100644 --- a/public/scripts/textgen-settings.js +++ b/public/scripts/textgen-settings.js @@ -916,6 +916,7 @@ async function generateTextGenWithStreaming(generate_data, signal) { /** @type {import('./logprobs.js').TokenLogprobs | null} */ let logprobs = null; const swipes = []; + const toolCalls = []; while (true) { const { done, value } = await reader.read(); if (done) return; @@ -934,7 +935,7 @@ async function generateTextGenWithStreaming(generate_data, signal) { logprobs = parseTextgenLogprobs(newText, data.choices?.[0]?.logprobs || data?.completion_probabilities); } - yield { text, swipes, logprobs }; + yield { text, swipes, logprobs, toolCalls }; } }; } From 0f8c1fa95d7bdfea5d69ffed3c0ee76eef444886 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 2 Oct 2024 22:17:27 +0300 Subject: [PATCH 07/50] Save tool calls to visible chats. --- public/script.js | 26 ++++++++------ public/scripts/openai.js | 23 +++++++----- public/scripts/tool-calling.js | 66 +++++++++++++++++++++------------- 3 files changed, 71 insertions(+), 44 deletions(-) diff --git a/public/script.js b/public/script.js index 3312c556f..1651d786e 100644 --- a/public/script.js +++ b/public/script.js @@ -3571,7 +3571,9 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro } // Collect messages with usable content - let coreChat = chat.filter(x => !x.is_system); + const canUseTools = ToolManager.isToolCallingSupported(); + const canPerformToolCalls = !dryRun && ToolManager.canPerformToolCalls(type); + let coreChat = chat.filter(x => !x.is_system || (canUseTools && Array.isArray(x.extra?.tool_invocations))); if (type === 'swipe') { coreChat.pop(); } @@ -4406,8 +4408,8 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro getMessage = continue_mag + getMessage; } - if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { - const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls); + if (canPerformToolCalls && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { + const invocations = await ToolManager.invokeFunctionTools(streamingProcessor.toolCalls); if (Array.isArray(invocations) && invocations.length) { const lastMessage = chat[chat.length - 1]; const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); @@ -4455,14 +4457,6 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro throw new Error(data?.response); } - if (ToolManager.isFunctionCallingSupported()) { - const invocations = await ToolManager.checkFunctionToolCalls(data); - if (Array.isArray(invocations) && invocations.length) { - ToolManager.saveFunctionToolInvocations(invocations); - return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); - } - } - //const getData = await response.json(); let getMessage = extractMessageFromData(data); let title = extractTitleFromData(data); @@ -4502,6 +4496,16 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro parseAndSaveLogprobs(data, continue_mag); } + if (canPerformToolCalls) { + const invocations = await ToolManager.invokeFunctionTools(data); + if (Array.isArray(invocations) && invocations.length) { + const shouldDeleteMessage = ['', '...'].includes(getMessage); + shouldDeleteMessage && await deleteLastMessage(); + ToolManager.saveFunctionToolInvocations(invocations); + return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); + } + } + if (type !== 'quiet') { playMessageSound(); } diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 96821eda4..e9216eccf 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -703,7 +703,7 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul } const imageInlining = isImageInliningSupported(); - const toolCalling = ToolManager.isFunctionCallingSupported(); + const canUseTools = ToolManager.isToolCallingSupported(); // Insert chat messages as long as there is budget available const chatPool = [...messages].reverse(); @@ -725,10 +725,10 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul await chatMessage.addImage(chatPrompt.image); } - if (toolCalling && Array.isArray(chatPrompt.invocations)) { + if (canUseTools && Array.isArray(chatPrompt.invocations)) { /** @type {import('./tool-calling.js').ToolInvocation[]} */ const invocations = chatPrompt.invocations; - const toolCallMessage = new Message('assistant', undefined, 'toolCall-' + chatMessage.identifier); + const toolCallMessage = new Message(chatMessage.role, undefined, 'toolCall-' + chatMessage.identifier); toolCallMessage.setToolCalls(invocations); if (chatCompletion.canAfford(toolCallMessage)) { chatCompletion.reserveBudget(toolCallMessage); @@ -1285,7 +1285,7 @@ export async function prepareOpenAIMessages({ const eventData = { chat, dryRun }; await eventSource.emit(event_types.CHAT_COMPLETION_PROMPT_READY, eventData); - openai_messages_count = chat.filter(x => x?.role === 'user' || x?.role === 'assistant')?.length || 0; + openai_messages_count = chat.filter(x => !x?.tool_calls && (x?.role === 'user' || x?.role === 'assistant'))?.length || 0; return [chat, promptManager.tokenHandler.counts]; } @@ -1886,7 +1886,7 @@ async function sendOpenAIRequest(type, messages, signal) { generate_data['seed'] = oai_settings.seed; } - if (!canMultiSwipe && ToolManager.isFunctionCallingSupported()) { + if (!canMultiSwipe && ToolManager.canPerformToolCalls(type)) { await ToolManager.registerFunctionToolsOpenAI(generate_data); } @@ -2393,13 +2393,20 @@ class MessageCollection { } /** - * Get chat in the format of {role, name, content}. + * Get chat in the format of {role, name, content, tool_calls}. * @returns {Array} Array of objects with role, name, and content properties. */ getChat() { return this.collection.reduce((acc, message) => { - const name = message.name; - if (message.content) acc.push({ role: message.role, ...(name && { name }), content: message.content }); + if (message.content || message.tool_calls) { + acc.push({ + role: message.role, + content: message.content, + ...(message.name && { name: message.name }), + ...(message.tool_calls && { tool_calls: message.tool_calls }), + ...(message.role === 'tool' && { tool_call_id: message.identifier }), + }); + } return acc; }, []); } diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index 964093bc3..def0c078a 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -1,4 +1,4 @@ -import { chat, main_api } from '../script.js'; +import { addOneMessage, chat, main_api, system_avatar, systemUserName } from '../script.js'; import { chat_completion_sources, oai_settings } from './openai.js'; /** @@ -243,12 +243,12 @@ export class ToolManager { } } - static isFunctionCallingSupported() { - if (main_api !== 'openai') { - return false; - } - - if (!oai_settings.function_calling) { + /** + * Checks if tool calling is supported for the current settings and generation type. + * @returns {boolean} Whether tool calling is supported for the given type + */ + static isToolCallingSupported() { + if (main_api !== 'openai' || !oai_settings.function_calling) { return false; } @@ -264,6 +264,22 @@ export class ToolManager { return supportedSources.includes(oai_settings.chat_completion_source); } + /** + * Checks if tool calls can be performed for the current settings and generation type. + * @param {string} type Generation type + * @returns {boolean} Whether tool calls can be performed for the given type + */ + static canPerformToolCalls(type) { + const noToolCallTypes = ['swipe', 'impersonate', 'quiet', 'continue']; + const isSupported = ToolManager.isToolCallingSupported(); + return isSupported && !noToolCallTypes.includes(type); + } + + /** + * Utility function to get tool calls from the response data. + * @param {any} data Response data + * @returns {any[]} Tool calls from the response data + */ static #getToolCallsFromData(data) { // Parsed tool calls from streaming data if (Array.isArray(data) && data.length > 0) { @@ -290,15 +306,11 @@ export class ToolManager { * @param {any} data Reply data * @returns {Promise} Successful tool invocations */ - static async checkFunctionToolCalls(data) { - if (!ToolManager.isFunctionCallingSupported()) { - return []; - } - + static async invokeFunctionTools(data) { /** @type {ToolInvocation[]} */ const invocations = []; const toolCalls = ToolManager.#getToolCallsFromData(data); - const oaiCompat = [ + const oaiCompatibleSources = [ chat_completion_sources.OPENAI, chat_completion_sources.CUSTOM, chat_completion_sources.MISTRALAI, @@ -306,7 +318,7 @@ export class ToolManager { chat_completion_sources.GROQ, ]; - if (oaiCompat.includes(oai_settings.chat_completion_source)) { + if (oaiCompatibleSources.includes(oai_settings.chat_completion_source)) { if (!Array.isArray(toolCalls)) { return []; } @@ -323,7 +335,7 @@ export class ToolManager { toastr.info('Invoking function tool: ' + name); const result = await ToolManager.invokeFunctionTool(name, parameters); - toastr.info('Function tool result: ' + result); + console.log('Function tool result:', result); // Save a successful invocation if (result) { @@ -367,15 +379,19 @@ export class ToolManager { * @param {ToolInvocation[]} invocations Successful tool invocations */ static saveFunctionToolInvocations(invocations) { - for (let index = chat.length - 1; index >= 0; index--) { - const message = chat[index]; - if (message.is_user) { - if (!message.extra || typeof message.extra !== 'object') { - message.extra = {}; - } - message.extra.tool_invocations = invocations; - break; - } - } + const toolNames = invocations.map(i => i.name).join(', '); + const message = { + name: systemUserName, + force_avatar: system_avatar, + is_system: true, + is_user: false, + mes: `Performed tool calls: ${toolNames}`, + extra: { + isSmallSys: true, + tool_invocations: invocations, + }, + }; + chat.push(message); + addOneMessage(message); } } From 2dad86e0769f803f674995a3b15ca48369144c36 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 2 Oct 2024 23:12:49 +0300 Subject: [PATCH 08/50] Delete empty message before tool invocations --- public/script.js | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/public/script.js b/public/script.js index c3aa593e8..3a38a04a9 100644 --- a/public/script.js +++ b/public/script.js @@ -4409,14 +4409,12 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro } if (canPerformToolCalls && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { + const lastMessage = chat[chat.length - 1]; + const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); + shouldDeleteMessage && await deleteLastMessage(); const invocations = await ToolManager.invokeFunctionTools(streamingProcessor.toolCalls); if (Array.isArray(invocations) && invocations.length) { - const lastMessage = chat[chat.length - 1]; - const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); - if (shouldDeleteMessage) { - await deleteLastMessage(); - streamingProcessor = null; - } + streamingProcessor = null; ToolManager.saveFunctionToolInvocations(invocations); return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); } @@ -4497,10 +4495,10 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro } if (canPerformToolCalls) { + const shouldDeleteMessage = ['', '...'].includes(getMessage); + shouldDeleteMessage && await deleteLastMessage(); const invocations = await ToolManager.invokeFunctionTools(data); if (Array.isArray(invocations) && invocations.length) { - const shouldDeleteMessage = ['', '...'].includes(getMessage); - shouldDeleteMessage && await deleteLastMessage(); ToolManager.saveFunctionToolInvocations(invocations); return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); } From 2b7c03f3b0c9662471880b25b1ee83ea90f51860 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 2 Oct 2024 23:13:11 +0300 Subject: [PATCH 09/50] Nicely format message for tool calls --- public/scripts/tool-calling.js | 26 +++++++++++++++++++++++--- public/style.css | 9 +++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index def0c078a..b3c8ac1b1 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -108,6 +108,7 @@ export class ToolManager { const definition = new ToolDefinition(name, description, parameters, action); this.#tools.set(name, definition); + console.log('[ToolManager] Registered function tool:', definition); } /** @@ -121,6 +122,7 @@ export class ToolManager { } this.#tools.delete(name); + console.log(`[ToolManager] Unregistered function tool: ${name}`); } /** @@ -333,8 +335,9 @@ export class ToolManager { const parameters = toolCall.function.arguments; const name = toolCall.function.name; - toastr.info('Invoking function tool: ' + name); + const toast = toastr.info(`Invoking function tool: ${name}`); const result = await ToolManager.invokeFunctionTool(name, parameters); + toastr.clear(toast); console.log('Function tool result:', result); // Save a successful invocation @@ -374,18 +377,35 @@ export class ToolManager { return invocations; } + /** + * Formats a message with tool invocations. + * @param {ToolInvocation[]} invocations Tool invocations. + * @returns {string} Formatted message with tool invocations. + */ + static #formatMessage(invocations) { + const toolNames = invocations.map(i => i.name).join(', '); + const detailsElement = document.createElement('details'); + const summaryElement = document.createElement('summary'); + const preElement = document.createElement('pre'); + const codeElement = document.createElement('code'); + codeElement.textContent = JSON.stringify(invocations, null, 2); + summaryElement.textContent = `Performed tool calls: ${toolNames}`; + preElement.append(codeElement); + detailsElement.append(summaryElement, preElement); + return detailsElement.outerHTML; + } + /** * Saves function tool invocations to the last user chat message extra metadata. * @param {ToolInvocation[]} invocations Successful tool invocations */ static saveFunctionToolInvocations(invocations) { - const toolNames = invocations.map(i => i.name).join(', '); const message = { name: systemUserName, force_avatar: system_avatar, is_system: true, is_user: false, - mes: `Performed tool calls: ${toolNames}`, + mes: ToolManager.#formatMessage(invocations), extra: { isSmallSys: true, tool_invocations: invocations, diff --git a/public/style.css b/public/style.css index 01a49ab9f..19a9bb738 100644 --- a/public/style.css +++ b/public/style.css @@ -418,6 +418,15 @@ small { text-align: center; } +.mes.smallSysMes pre { + text-align: initial; + word-break: break-all; +} + +.mes.smallSysMes summary { + cursor: pointer; +} + .mes.smallSysMes .mes_text p:last-child { margin: 0; } From 1e076a3e4367cc228bb63cadaed3812298897308 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 2 Oct 2024 23:32:29 +0300 Subject: [PATCH 10/50] Prettify displayed message --- public/scripts/tool-calling.js | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index b3c8ac1b1..dd1d1d8c4 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -342,7 +342,7 @@ export class ToolManager { // Save a successful invocation if (result) { - invocations.push({ id, name, result, parameters }); + invocations.push({ id, name, parameters, result }); } } } @@ -383,12 +383,16 @@ export class ToolManager { * @returns {string} Formatted message with tool invocations. */ static #formatMessage(invocations) { - const toolNames = invocations.map(i => i.name).join(', '); + const tryParse = (x) => { try { return JSON.parse(x); } catch { return x; } }; + const data = structuredClone(invocations); const detailsElement = document.createElement('details'); const summaryElement = document.createElement('summary'); const preElement = document.createElement('pre'); const codeElement = document.createElement('code'); - codeElement.textContent = JSON.stringify(invocations, null, 2); + codeElement.classList.add('language-json'); + data.forEach(i => i.parameters = tryParse(i.parameters)); + codeElement.textContent = JSON.stringify(data, null, 2); + const toolNames = data.map(i => i.name).join(', '); summaryElement.textContent = `Performed tool calls: ${toolNames}`; preElement.append(codeElement); detailsElement.append(summaryElement, preElement); From da9200c82ed39f790b3ff982e796f33aeaec3ae1 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Thu, 3 Oct 2024 12:59:59 +0000 Subject: [PATCH 11/50] Skip adding tool messages as regular chats --- public/scripts/openai.js | 2 ++ 1 file changed, 2 insertions(+) diff --git a/public/scripts/openai.js b/public/scripts/openai.js index e9216eccf..9ab05b467 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -743,6 +743,8 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul chatCompletion.freeBudget(toolCallMessage); chatCompletion.insertAtStart(toolCallMessage, 'chatHistory'); } + + continue; } if (chatCompletion.canAfford(chatMessage)) { From 90809852c2a7e1ecc133003b536d237908e1f498 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Thu, 3 Oct 2024 13:23:53 +0000 Subject: [PATCH 12/50] Hide message on streaming tool calls --- public/script.js | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/public/script.js b/public/script.js index 3a38a04a9..e3c7bf1ba 100644 --- a/public/script.js +++ b/public/script.js @@ -2932,6 +2932,13 @@ class StreamingProcessor { } } + #updateMessageBlockVisibility() { + if (this.messageDom instanceof HTMLElement && Array.isArray(this.toolCalls) && this.toolCalls.length > 0) { + const shouldHide = ['', '...'].includes(this.result); + this.messageDom.classList.toggle('displayNone', shouldHide); + } + } + showMessageButtons(messageId) { if (messageId == -1) { return; @@ -2997,6 +3004,7 @@ class StreamingProcessor { } else { this.#checkDomElements(messageId); + this.#updateMessageBlockVisibility(); const currentTime = new Date(); // Don't waste time calculating token count for streaming const currentTokenCount = isFinal && power_user.message_token_count_enabled ? getTokenCount(processedText, 0) : 0; From 6558b106754a73faff6ca171254192ed9a1157e9 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 00:11:36 +0300 Subject: [PATCH 13/50] Show an error when all tools fail --- public/script.js | 27 +++++++++++++---- public/scripts/tool-calling.js | 54 +++++++++++++++++++++++++++------- public/style.css | 1 + 3 files changed, 66 insertions(+), 16 deletions(-) diff --git a/public/script.js b/public/script.js index e3c7bf1ba..4d59bde08 100644 --- a/public/script.js +++ b/public/script.js @@ -4420,10 +4420,18 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro const lastMessage = chat[chat.length - 1]; const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); shouldDeleteMessage && await deleteLastMessage(); - const invocations = await ToolManager.invokeFunctionTools(streamingProcessor.toolCalls); - if (Array.isArray(invocations) && invocations.length) { + const invocationResult = await ToolManager.invokeFunctionTools(streamingProcessor.toolCalls); + if (invocationResult.hadToolCalls) { + if (!invocationResult.invocations.length && shouldDeleteMessage) { + ToolManager.showToolCallError(invocationResult.errors); + unblockGeneration(type); + generatedPromptCache = ''; + streamingProcessor = null; + return; + } + streamingProcessor = null; - ToolManager.saveFunctionToolInvocations(invocations); + ToolManager.saveFunctionToolInvocations(invocationResult.invocations); return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); } } @@ -4505,9 +4513,16 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro if (canPerformToolCalls) { const shouldDeleteMessage = ['', '...'].includes(getMessage); shouldDeleteMessage && await deleteLastMessage(); - const invocations = await ToolManager.invokeFunctionTools(data); - if (Array.isArray(invocations) && invocations.length) { - ToolManager.saveFunctionToolInvocations(invocations); + const invocationResult = await ToolManager.invokeFunctionTools(data); + if (invocationResult.hadToolCalls) { + if (!invocationResult.invocations.length && shouldDeleteMessage) { + ToolManager.showToolCallError(invocationResult.errors); + unblockGeneration(type); + generatedPromptCache = ''; + return; + } + + ToolManager.saveFunctionToolInvocations(invocationResult.invocations); return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); } } diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index dd1d1d8c4..bd8c572f8 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -1,5 +1,6 @@ import { addOneMessage, chat, main_api, system_avatar, systemUserName } from '../script.js'; import { chat_completion_sources, oai_settings } from './openai.js'; +import { Popup } from './popup.js'; /** * @typedef {object} ToolInvocation @@ -9,6 +10,13 @@ import { chat_completion_sources, oai_settings } from './openai.js'; * @property {string} result - The result of the tool invocation. */ +/** + * @typedef {object} ToolInvocationResult + * @property {ToolInvocation[]} invocations Successful tool invocations + * @property {boolean} hadToolCalls Whether any tool calls were found + * @property {Error[]} errors Errors that occurred during tool invocation + */ + /** * A class that represents a tool definition. */ @@ -129,7 +137,7 @@ export class ToolManager { * Invokes a tool by name. Returns the result of the tool's action function. * @param {string} name The name of the tool to invoke. * @param {object} parameters Function parameters. For example, if the tool requires a "name" parameter, you would pass {name: "value"}. - * @returns {Promise} The result of the tool's action function. If an error occurs, null is returned. Non-string results are JSON-stringified. + * @returns {Promise} The result of the tool's action function. If an error occurs, null is returned. Non-string results are JSON-stringified. */ static async invokeFunctionTool(name, parameters) { try { @@ -143,7 +151,13 @@ export class ToolManager { return typeof result === 'string' ? result : JSON.stringify(result); } catch (error) { console.error(`An error occurred while invoking the tool "${name}":`, error); - return null; + + if (error instanceof Error) { + error.cause = name; + return error; + } + + return new Error('Unknown error occurred while invoking the tool.', { cause: name }); } } @@ -306,11 +320,15 @@ export class ToolManager { /** * Check for function tool calls in the response data and invoke them. * @param {any} data Reply data - * @returns {Promise} Successful tool invocations + * @returns {Promise} Successful tool invocations */ static async invokeFunctionTools(data) { - /** @type {ToolInvocation[]} */ - const invocations = []; + /** @type {ToolInvocationResult} */ + const result = { + invocations: [], + hadToolCalls: false, + errors: [], + }; const toolCalls = ToolManager.#getToolCallsFromData(data); const oaiCompatibleSources = [ chat_completion_sources.OPENAI, @@ -322,7 +340,7 @@ export class ToolManager { if (oaiCompatibleSources.includes(oai_settings.chat_completion_source)) { if (!Array.isArray(toolCalls)) { - return []; + return result; } for (const toolCall of toolCalls) { @@ -334,16 +352,20 @@ export class ToolManager { const id = toolCall.id; const parameters = toolCall.function.arguments; const name = toolCall.function.name; + result.hadToolCalls = true; const toast = toastr.info(`Invoking function tool: ${name}`); - const result = await ToolManager.invokeFunctionTool(name, parameters); + const toolResult = await ToolManager.invokeFunctionTool(name, parameters); toastr.clear(toast); console.log('Function tool result:', result); // Save a successful invocation - if (result) { - invocations.push({ id, name, parameters, result }); + if (toolResult instanceof Error) { + result.errors.push(toolResult); + continue; } + + result.invocations.push({ id, name, parameters, result: toolResult }); } } @@ -374,7 +396,7 @@ export class ToolManager { } */ - return invocations; + return result; } /** @@ -418,4 +440,16 @@ export class ToolManager { chat.push(message); addOneMessage(message); } + + /** + * Shows an error message for tool calls. + * @param {Error[]} errors Errors that occurred during tool invocation + * @returns {void} + */ + static showToolCallError(errors) { + toastr.error('An error occurred while invoking function tools. Click here for more details.', 'Tool Calling', { + onclick: () => Popup.show.text('Tool Calling Errors', DOMPurify.sanitize(errors.map(e => `${e.cause}: ${e.message}`).join('
'))), + timeOut: 5000, + }); + } } diff --git a/public/style.css b/public/style.css index 19a9bb738..96ae504b7 100644 --- a/public/style.css +++ b/public/style.css @@ -421,6 +421,7 @@ small { .mes.smallSysMes pre { text-align: initial; word-break: break-all; + margin-top: 5px; } .mes.smallSysMes summary { From 5cf64a2613252cc9ab8557fe145fab7546ea56ec Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 00:39:28 +0300 Subject: [PATCH 14/50] Update tool registration --- public/scripts/tool-calling.js | 100 +++++++++++++++++++++++++++++---- 1 file changed, 89 insertions(+), 11 deletions(-) diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index bd8c572f8..0eaada477 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -5,6 +5,7 @@ import { Popup } from './popup.js'; /** * @typedef {object} ToolInvocation * @property {string} id - A unique identifier for the tool invocation. + * @property {string} displayName - The display name of the tool. * @property {string} name - The name of the tool. * @property {string} parameters - The parameters for the tool invocation. * @property {string} result - The result of the tool invocation. @@ -27,6 +28,12 @@ class ToolDefinition { */ #name; + /** + * A user-friendly display name for the tool. + * @type {string} + */ + #displayName; + /** * A description of what the tool does. * @type {string} @@ -45,18 +52,28 @@ class ToolDefinition { */ #action; + /** + * A function that will be called to format the tool call toast. + * @type {function} + */ + #formatMessage; + /** * Creates a new ToolDefinition. * @param {string} name A unique name for the tool. + * @param {string} displayName A user-friendly display name for the tool. * @param {string} description A description of what the tool does. * @param {object} parameters A JSON schema for the parameters that the tool accepts. * @param {function} action A function that will be called when the tool is executed. + * @param {function} formatMessage A function that will be called to format the tool call toast. */ - constructor(name, description, parameters, action) { + constructor(name, displayName, description, parameters, action, formatMessage) { this.#name = name; + this.#displayName = displayName; this.#description = description; this.#parameters = parameters; this.#action = action; + this.#formatMessage = formatMessage; } /** @@ -82,6 +99,21 @@ class ToolDefinition { async invoke(parameters) { return await this.#action(parameters); } + + /** + * Formats a message with the tool invocation. + * @param {object} parameters The parameters to pass to the tool. + * @returns {string} The formatted message. + */ + formatMessage(parameters) { + return typeof this.#formatMessage === 'function' + ? this.#formatMessage(parameters) + : `Invoking tool: ${this.#displayName || this.#name}`; + } + + get displayName() { + return this.#displayName; + } } /** @@ -104,17 +136,25 @@ export class ToolManager { /** * Registers a new tool with the tool registry. - * @param {string} name The name of the tool. - * @param {string} description A description of what the tool does. - * @param {object} parameters A JSON schema for the parameters that the tool accepts. - * @param {function} action A function that will be called when the tool is executed. + * @param {object} tool The tool to register. + * @param {string} tool.name The name of the tool. + * @param {string} tool.displayName A user-friendly display name for the tool. + * @param {string} tool.description A description of what the tool does. + * @param {object} tool.parameters A JSON schema for the parameters that the tool accepts. + * @param {function} tool.action A function that will be called when the tool is executed. + * @param {function} tool.formatMessage A function that will be called to format the tool call toast. */ - static registerFunctionTool(name, description, parameters, action) { + static registerFunctionTool({ name, displayName, description, parameters, action, formatMessage }) { + // Convert WIP arguments + if (typeof arguments[0] !== 'object') { + [name, description, parameters, action] = arguments; + } + if (this.#tools.has(name)) { console.warn(`A tool with the name "${name}" has already been registered. The definition will be overwritten.`); } - const definition = new ToolDefinition(name, description, parameters, action); + const definition = new ToolDefinition(name, displayName, description, parameters, action, formatMessage); this.#tools.set(name, definition); console.log('[ToolManager] Registered function tool:', definition); } @@ -161,6 +201,35 @@ export class ToolManager { } } + static formatToolCallMessage(name, parameters) { + if (!this.#tools.has(name)) { + return `Invoked unknown tool: ${name}`; + } + + try { + const tool = this.#tools.get(name); + const formatParameters = typeof parameters === 'string' ? JSON.parse(parameters) : parameters; + return tool.formatMessage(formatParameters); + } catch (error) { + console.error(`An error occurred while formatting the tool call message for "${name}":`, error); + return `Invoking tool: ${name}`; + } + } + + /** + * Gets the display name of a tool by name. + * @param {string} name + * @returns {string} The display name of the tool. + */ + static getDisplayName(name) { + if (!this.#tools.has(name)) { + return name; + } + + const tool = this.#tools.get(name); + return tool.displayName || name; + } + /** * Register function tools for the next chat completion request. * @param {object} data Generation data @@ -352,9 +421,11 @@ export class ToolManager { const id = toolCall.id; const parameters = toolCall.function.arguments; const name = toolCall.function.name; + const displayName = ToolManager.getDisplayName(name); result.hadToolCalls = true; - const toast = toastr.info(`Invoking function tool: ${name}`); + const message = ToolManager.formatToolCallMessage(name, parameters); + const toast = message && toastr.info(message, 'Tool Calling', { timeOut: 0 }); const toolResult = await ToolManager.invokeFunctionTool(name, parameters); toastr.clear(toast); console.log('Function tool result:', result); @@ -365,7 +436,14 @@ export class ToolManager { continue; } - result.invocations.push({ id, name, parameters, result: toolResult }); + const invocation = { + id, + displayName, + name, + parameters, + result: toolResult, + }; + result.invocations.push(invocation); } } @@ -414,8 +492,8 @@ export class ToolManager { codeElement.classList.add('language-json'); data.forEach(i => i.parameters = tryParse(i.parameters)); codeElement.textContent = JSON.stringify(data, null, 2); - const toolNames = data.map(i => i.name).join(', '); - summaryElement.textContent = `Performed tool calls: ${toolNames}`; + const toolNames = data.map(i => i.displayName || i.name).join(', '); + summaryElement.textContent = `Tool calls: ${toolNames}`; preElement.append(codeElement); detailsElement.append(summaryElement, preElement); return detailsElement.outerHTML; From 6cb82fc21e3fc7996b5f3c4a1e480890038e47da Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 00:39:45 +0300 Subject: [PATCH 15/50] Add json param to /transcript --- src/endpoints/search.js | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/endpoints/search.js b/src/endpoints/search.js index 457124228..8c8ed7055 100644 --- a/src/endpoints/search.js +++ b/src/endpoints/search.js @@ -60,6 +60,7 @@ router.post('/transcript', jsonParser, async (request, response) => { const RE_XML_TRANSCRIPT = /([^<]*)<\/text>/g; const id = request.body.id; const lang = request.body.lang; + const json = request.body.json; if (!id) { console.log('Id is required for /transcript'); @@ -129,7 +130,9 @@ router.post('/transcript', jsonParser, async (request, response) => { // The text is double-encoded const transcriptText = transcript.map((line) => he.decode(he.decode(line.text))).join(' '); - return response.send(transcriptText); + return json + ? response.json({ transcript: transcriptText, html: videoPageBody }) + : response.send(transcriptText); } catch (error) { console.log(error); return response.sendStatus(500); From 777b2518bdb0153f194db79fb651c558e196fc3d Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 01:12:12 +0300 Subject: [PATCH 16/50] Allow returning page if transcript extraction failed --- src/endpoints/search.js | 132 +++++++++++++++++++++++----------------- 1 file changed, 75 insertions(+), 57 deletions(-) diff --git a/src/endpoints/search.js b/src/endpoints/search.js index 8c8ed7055..64993589e 100644 --- a/src/endpoints/search.js +++ b/src/endpoints/search.js @@ -22,6 +22,72 @@ const visitHeaders = { 'Sec-Fetch-User': '?1', }; +/** + * Extract the transcript of a YouTube video + * @param {string} videoPageBody HTML of the video page + * @param {string} lang Language code + * @returns {Promise} Transcript text + */ +async function extractTranscript(videoPageBody, lang) { + const he = require('he'); + const RE_XML_TRANSCRIPT = /([^<]*)<\/text>/g; + const splittedHTML = videoPageBody.split('"captions":'); + + if (splittedHTML.length <= 1) { + if (videoPageBody.includes('class="g-recaptcha"')) { + throw new Error('Too many requests'); + } + if (!videoPageBody.includes('"playabilityStatus":')) { + throw new Error('Video is not available'); + } + throw new Error('Transcript not available'); + } + + const captions = (() => { + try { + return JSON.parse(splittedHTML[1].split(',"videoDetails')[0].replace('\n', '')); + } catch (e) { + return undefined; + } + })()?.['playerCaptionsTracklistRenderer']; + + if (!captions) { + throw new Error('Transcript disabled'); + } + + if (!('captionTracks' in captions)) { + throw new Error('Transcript not available'); + } + + if (lang && !captions.captionTracks.some(track => track.languageCode === lang)) { + throw new Error('Transcript not available in this language'); + } + + const transcriptURL = (lang ? captions.captionTracks.find(track => track.languageCode === lang) : captions.captionTracks[0]).baseUrl; + const transcriptResponse = await fetch(transcriptURL, { + headers: { + ...(lang && { 'Accept-Language': lang }), + 'User-Agent': visitHeaders['User-Agent'], + }, + }); + + if (!transcriptResponse.ok) { + throw new Error('Transcript request failed'); + } + + const transcriptBody = await transcriptResponse.text(); + const results = [...transcriptBody.matchAll(RE_XML_TRANSCRIPT)]; + const transcript = results.map((result) => ({ + text: result[3], + duration: parseFloat(result[2]), + offset: parseFloat(result[1]), + lang: lang ?? captions.captionTracks[0].languageCode, + })); + // The text is double-encoded + const transcriptText = transcript.map((line) => he.decode(he.decode(line.text))).join(' '); + return transcriptText; +} + router.post('/serpapi', jsonParser, async (request, response) => { try { const key = readSecret(request.user.directories, SECRET_KEYS.SERPAPI); @@ -56,8 +122,6 @@ router.post('/serpapi', jsonParser, async (request, response) => { */ router.post('/transcript', jsonParser, async (request, response) => { try { - const he = require('he'); - const RE_XML_TRANSCRIPT = /([^<]*)<\/text>/g; const id = request.body.id; const lang = request.body.lang; const json = request.body.json; @@ -75,64 +139,18 @@ router.post('/transcript', jsonParser, async (request, response) => { }); const videoPageBody = await videoPageResponse.text(); - const splittedHTML = videoPageBody.split('"captions":'); - if (splittedHTML.length <= 1) { - if (videoPageBody.includes('class="g-recaptcha"')) { - throw new Error('Too many requests'); + try { + const transcriptText = await extractTranscript(videoPageBody, lang); + return json + ? response.json({ transcript: transcriptText, html: videoPageBody }) + : response.send(transcriptText); + } catch (error) { + if (json) { + return response.json({ html: videoPageBody, transcript: '' }); } - if (!videoPageBody.includes('"playabilityStatus":')) { - throw new Error('Video is not available'); - } - throw new Error('Transcript not available'); + throw error; } - - const captions = (() => { - try { - return JSON.parse(splittedHTML[1].split(',"videoDetails')[0].replace('\n', '')); - } catch (e) { - return undefined; - } - })()?.['playerCaptionsTracklistRenderer']; - - if (!captions) { - throw new Error('Transcript disabled'); - } - - if (!('captionTracks' in captions)) { - throw new Error('Transcript not available'); - } - - if (lang && !captions.captionTracks.some(track => track.languageCode === lang)) { - throw new Error('Transcript not available in this language'); - } - - const transcriptURL = (lang ? captions.captionTracks.find(track => track.languageCode === lang) : captions.captionTracks[0]).baseUrl; - const transcriptResponse = await fetch(transcriptURL, { - headers: { - ...(lang && { 'Accept-Language': lang }), - 'User-Agent': visitHeaders['User-Agent'], - }, - }); - - if (!transcriptResponse.ok) { - throw new Error('Transcript request failed'); - } - - const transcriptBody = await transcriptResponse.text(); - const results = [...transcriptBody.matchAll(RE_XML_TRANSCRIPT)]; - const transcript = results.map((result) => ({ - text: result[3], - duration: parseFloat(result[2]), - offset: parseFloat(result[1]), - lang: lang ?? captions.captionTracks[0].languageCode, - })); - // The text is double-encoded - const transcriptText = transcript.map((line) => he.decode(he.decode(line.text))).join(' '); - - return json - ? response.json({ transcript: transcriptText, html: videoPageBody }) - : response.send(transcriptText); } catch (error) { console.log(error); return response.sendStatus(500); From 447a7fba68298cdae8becd82acb3b2f97dc785fd Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 01:45:37 +0300 Subject: [PATCH 17/50] Only delete message if had successful tool calls --- public/script.js | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/public/script.js b/public/script.js index 4d59bde08..22ce72641 100644 --- a/public/script.js +++ b/public/script.js @@ -4417,11 +4417,11 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro } if (canPerformToolCalls && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { - const lastMessage = chat[chat.length - 1]; - const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); - shouldDeleteMessage && await deleteLastMessage(); const invocationResult = await ToolManager.invokeFunctionTools(streamingProcessor.toolCalls); if (invocationResult.hadToolCalls) { + const lastMessage = chat[chat.length - 1]; + const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); + shouldDeleteMessage && await deleteLastMessage(); if (!invocationResult.invocations.length && shouldDeleteMessage) { ToolManager.showToolCallError(invocationResult.errors); unblockGeneration(type); @@ -4511,10 +4511,10 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro } if (canPerformToolCalls) { - const shouldDeleteMessage = ['', '...'].includes(getMessage); - shouldDeleteMessage && await deleteLastMessage(); const invocationResult = await ToolManager.invokeFunctionTools(data); if (invocationResult.hadToolCalls) { + const shouldDeleteMessage = ['', '...'].includes(getMessage); + shouldDeleteMessage && await deleteLastMessage(); if (!invocationResult.invocations.length && shouldDeleteMessage) { ToolManager.showToolCallError(invocationResult.errors); unblockGeneration(type); From 01f03dbf508a00240e3019c517d8f61166af9eb1 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 01:51:41 +0300 Subject: [PATCH 18/50] Support MistralAI streaming tool calls --- public/scripts/tool-calling.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index 0eaada477..e0e3c53e0 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -278,9 +278,9 @@ export class ToolManager { } for (const toolCallDelta of toolCallDeltas) { - const toolCallIndex = (typeof toolCallDelta?.index === 'number') ? toolCallDelta.index : null; + const toolCallIndex = (typeof toolCallDelta?.index === 'number') ? toolCallDelta.index : toolCallDeltas.indexOf(toolCallDelta); - if (toolCallIndex === null) { + if (isNaN(toolCallIndex) || toolCallIndex < 0) { continue; } From 559f1b81f79b2c4e2dc68538d29ead31ed738c36 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 02:11:46 +0300 Subject: [PATCH 19/50] Remove tool calling for Cohere v1 --- public/scripts/tool-calling.js | 13 ---- src/endpoints/backends/chat-completions.js | 10 +-- src/prompt-converters.js | 71 ---------------------- 3 files changed, 1 insertion(+), 93 deletions(-) diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index e0e3c53e0..2f101d9ba 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -461,19 +461,6 @@ export class ToolManager { } */ - /* - if ([chat_completion_sources.COHERE].includes(oai_settings.chat_completion_source)) { - if (!Array.isArray(data?.tool_calls)) { - return; - } - - for (const toolCall of data.tool_calls) { - const args = { name: toolCall.name, arguments: JSON.stringify(toolCall.parameters) }; - console.log('Function tool call:', toolCall); - } - } - */ - return result; } diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index 55cb08c70..552ebefbd 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -4,7 +4,7 @@ const fetch = require('node-fetch').default; const { jsonParser } = require('../../express-common'); const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants'); const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util'); -const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools, convertAI21Messages } = require('../../prompt-converters'); +const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertAI21Messages } = require('../../prompt-converters'); const CohereStream = require('../../cohere-stream'); const { readSecret, SECRET_KEYS } = require('../secrets'); @@ -555,14 +555,6 @@ async function sendCohereRequest(request, response) { }); } - /* - if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { - tools.push(...convertCohereTools(request.body.tools)); - // Can't have both connectors and tools in the same request - connectors.splice(0, connectors.length); - } - */ - // https://docs.cohere.com/reference/chat const requestBody = { stream: Boolean(request.body.stream), diff --git a/src/prompt-converters.js b/src/prompt-converters.js index 5d52adb6e..2f6aaf960 100644 --- a/src/prompt-converters.js +++ b/src/prompt-converters.js @@ -522,76 +522,6 @@ function convertTextCompletionPrompt(messages) { return messageStrings.join('\n') + '\nassistant:'; } -/** - * Convert OpenAI Chat Completion tools to the format used by Cohere. - * @param {object[]} tools OpenAI Chat Completion tool definitions - */ -function convertCohereTools(tools) { - if (!Array.isArray(tools) || tools.length === 0) { - return []; - } - - const jsonSchemaToPythonTypes = { - 'string': 'str', - 'number': 'float', - 'integer': 'int', - 'boolean': 'bool', - 'array': 'list', - 'object': 'dict', - }; - - const cohereTools = []; - - for (const tool of tools) { - if (tool?.type !== 'function') { - console.log(`Unsupported tool type: ${tool.type}`); - continue; - } - - const name = tool?.function?.name; - const description = tool?.function?.description; - const properties = tool?.function?.parameters?.properties; - const required = tool?.function?.parameters?.required; - const parameters = {}; - - if (!name) { - console.log('Tool name is missing'); - continue; - } - - if (!description) { - console.log('Tool description is missing'); - } - - if (!properties || typeof properties !== 'object') { - console.log(`No properties found for tool: ${tool?.function?.name}`); - continue; - } - - for (const property in properties) { - const parameterDefinition = properties[property]; - const description = parameterDefinition.description || (parameterDefinition.enum ? JSON.stringify(parameterDefinition.enum) : ''); - const type = jsonSchemaToPythonTypes[parameterDefinition.type] || 'str'; - const isRequired = Array.isArray(required) && required.includes(property); - parameters[property] = { - description: description, - type: type, - required: isRequired, - }; - } - - const cohereTool = { - name: tool.function.name, - description: tool.function.description, - parameter_definitions: parameters, - }; - - cohereTools.push(cohereTool); - } - - return cohereTools; -} - module.exports = { convertClaudePrompt, convertClaudeMessages, @@ -599,6 +529,5 @@ module.exports = { convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, - convertCohereTools, convertAI21Messages, }; From c3c10a629e000f3d7479f9a50fbb6d28316acaa0 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 03:41:25 +0300 Subject: [PATCH 20/50] Claude: new prompt converter + non-streaming tools --- public/scripts/tool-calling.js | 111 +++++++++--------- src/endpoints/backends/chat-completions.js | 2 - src/prompt-converters.js | 127 ++++++++++++++------- 3 files changed, 141 insertions(+), 99 deletions(-) diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index 2f101d9ba..15d3db954 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -339,10 +339,9 @@ export class ToolManager { const supportedSources = [ chat_completion_sources.OPENAI, - //chat_completion_sources.COHERE, chat_completion_sources.CUSTOM, chat_completion_sources.MISTRALAI, - //chat_completion_sources.CLAUDE, + chat_completion_sources.CLAUDE, chat_completion_sources.OPENROUTER, chat_completion_sources.GROQ, ]; @@ -372,18 +371,29 @@ export class ToolManager { } // Parsed tool calls from non-streaming data - if (!Array.isArray(data?.choices)) { - return; + if (Array.isArray(data?.choices)) { + // Find a choice with 0-index + const choice = data.choices.find(choice => choice.index === 0); + + if (choice) { + return choice.message.tool_calls; + } } - // Find a choice with 0-index - const choice = data.choices.find(choice => choice.index === 0); + if (Array.isArray(data?.content)) { + // Claude tool calls to OpenAI tool calls + const content = data.content.filter(c => c.type === 'tool_use').map(c => { + return { + id: c.id, + function: { + name: c.name, + arguments: c.input, + }, + }; + }); - if (!choice) { - return; + return content; } - - return choice.message.tool_calls; } /** @@ -407,59 +417,43 @@ export class ToolManager { chat_completion_sources.GROQ, ]; - if (oaiCompatibleSources.includes(oai_settings.chat_completion_source)) { - if (!Array.isArray(toolCalls)) { - return result; - } - - for (const toolCall of toolCalls) { - if (typeof toolCall.function !== 'object') { - continue; - } - - console.log('Function tool call:', toolCall); - const id = toolCall.id; - const parameters = toolCall.function.arguments; - const name = toolCall.function.name; - const displayName = ToolManager.getDisplayName(name); - result.hadToolCalls = true; - - const message = ToolManager.formatToolCallMessage(name, parameters); - const toast = message && toastr.info(message, 'Tool Calling', { timeOut: 0 }); - const toolResult = await ToolManager.invokeFunctionTool(name, parameters); - toastr.clear(toast); - console.log('Function tool result:', result); - - // Save a successful invocation - if (toolResult instanceof Error) { - result.errors.push(toolResult); - continue; - } - - const invocation = { - id, - displayName, - name, - parameters, - result: toolResult, - }; - result.invocations.push(invocation); - } + if (!Array.isArray(toolCalls)) { + return result; } - /* - if ([chat_completion_sources.CLAUDE].includes(oai_settings.chat_completion_source)) { - if (!Array.isArray(data?.content)) { - return; + for (const toolCall of toolCalls) { + if (typeof toolCall.function !== 'object') { + continue; } - for (const content of data.content) { - if (content.type === 'tool_use') { - const args = { name: content.name, arguments: JSON.stringify(content.input) }; - } + console.log('Function tool call:', toolCall); + const id = toolCall.id; + const parameters = toolCall.function.arguments; + const name = toolCall.function.name; + const displayName = ToolManager.getDisplayName(name); + result.hadToolCalls = true; + + const message = ToolManager.formatToolCallMessage(name, parameters); + const toast = message && toastr.info(message, 'Tool Calling', { timeOut: 0 }); + const toolResult = await ToolManager.invokeFunctionTool(name, parameters); + toastr.clear(toast); + console.log('Function tool result:', result); + + // Save a successful invocation + if (toolResult instanceof Error) { + result.errors.push(toolResult); + continue; } + + const invocation = { + id, + displayName, + name, + parameters, + result: toolResult, + }; + result.invocations.push(invocation); } - */ return result; } @@ -491,6 +485,9 @@ export class ToolManager { * @param {ToolInvocation[]} invocations Successful tool invocations */ static saveFunctionToolInvocations(invocations) { + if (!Array.isArray(invocations) || invocations.length === 0) { + return; + } const message = { name: systemUserName, force_avatar: system_avatar, diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index 552ebefbd..9c29a3b44 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -124,7 +124,6 @@ async function sendClaudeRequest(request, response) { } else { delete requestBody.system; } - /* if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { // Claude doesn't do prefills on function calls, and doesn't allow empty messages if (convertedPrompt.messages.length && convertedPrompt.messages[convertedPrompt.messages.length - 1].role === 'assistant') { @@ -137,7 +136,6 @@ async function sendClaudeRequest(request, response) { .map(tool => tool.function) .map(fn => ({ name: fn.name, description: fn.description, input_schema: fn.parameters })); } - */ if (enableSystemPromptCache) { additionalHeaders['anthropic-beta'] = 'prompt-caching-2024-07-31'; } diff --git a/src/prompt-converters.js b/src/prompt-converters.js index 2f6aaf960..b4b3d718b 100644 --- a/src/prompt-converters.js +++ b/src/prompt-converters.js @@ -118,8 +118,27 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi }); } } + // Now replace all further messages that have the role 'system' with the role 'user'. (or all if we're not using one) messages.forEach((message) => { + if (message.role === 'assistant' && message.tool_calls) { + message.content = message.tool_calls.map((tc) => ({ + type: 'tool_use', + id: tc.id, + name: tc.function.name, + input: tc.function.arguments, + })); + } + + if (message.role === 'tool') { + message.role = 'user'; + message.content = [{ + type: 'tool_result', + tool_use_id: message.tool_call_id, + content: message.content, + }]; + } + if (message.role === 'system') { if (userName && message.name === 'example_user') { message.content = `${userName}: ${message.content}`; @@ -128,13 +147,80 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi message.content = `${charName}: ${message.content}`; } message.role = 'user'; + + // Delete name here so it doesn't get added later + delete message.name; } + + // Convert everything to an array of it would be easier to work with + if (typeof message.content === 'string') { + // Take care of name properties since claude messages don't support them + if (message.name) { + message.content = `${message.name}: ${message.content}`; + } + + message.content = [{ type: 'text', text: message.content }]; + } else if (Array.isArray(message.content)) { + message.content = message.content.map((content) => { + if (content.type === 'image_url') { + const imageEntry = content?.image_url; + const imageData = imageEntry?.url; + const mimeType = imageData?.split(';')?.[0].split(':')?.[1]; + const base64Data = imageData?.split(',')?.[1]; + + return { + type: 'image', + source: { + type: 'base64', + media_type: mimeType, + data: base64Data, + }, + }; + } + + if (content.type === 'text') { + if (message.name) { + content.text = `${message.name}: ${content.text}`; + } + + return content; + } + + return content; + }); + } + + // Remove offending properties + delete message.name; + delete message.tool_calls; + delete message.tool_call_id; }); + // Images in assistant messages should be moved to the next user message + for (let i = 0; i < messages.length; i++) { + if (messages[i].role === 'assistant' && messages[i].content.some(c => c.type === 'image')) { + // Find the next user message + let j = i + 1; + while (j < messages.length && messages[j].role !== 'user') { + j++; + } + + // Move the images + if (j >= messages.length) { + // If there is no user message after the assistant message, add a new one + messages.splice(i + 1, 0, { role: 'user', content: [] }); + } + + messages[j].content.push(...messages[i].content.filter(c => c.type === 'image')); + messages[i].content = messages[i].content.filter(c => c.type !== 'image'); + } + } + // Shouldn't be conditional anymore, messages api expects the last role to be user unless we're explicitly prefilling if (prefillString) { messages.push({ role: 'assistant', + // Dangling whitespace are not allowed for prefilling content: prefillString.trimEnd(), }); } @@ -143,50 +229,11 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi // Also handle multi-modality, holy slop. let mergedMessages = []; messages.forEach((message) => { - const imageEntry = message.content?.[1]?.image_url; - const imageData = imageEntry?.url; - const mimeType = imageData?.split(';')?.[0].split(':')?.[1]; - const base64Data = imageData?.split(',')?.[1]; - - // Take care of name properties since claude messages don't support them - if (message.name) { - if (Array.isArray(message.content)) { - message.content[0].text = `${message.name}: ${message.content[0].text}`; - } else { - message.content = `${message.name}: ${message.content}`; - } - delete message.name; - } - if (mergedMessages.length > 0 && mergedMessages[mergedMessages.length - 1].role === message.role) { - if (Array.isArray(message.content)) { - if (Array.isArray(mergedMessages[mergedMessages.length - 1].content)) { - mergedMessages[mergedMessages.length - 1].content[0].text += '\n\n' + message.content[0].text; - } else { - mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content[0].text; - } - } else { - if (Array.isArray(mergedMessages[mergedMessages.length - 1].content)) { - mergedMessages[mergedMessages.length - 1].content[0].text += '\n\n' + message.content; - } else { - mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content; - } - } + mergedMessages[mergedMessages.length - 1].content.push(...message.content); } else { mergedMessages.push(message); } - if (imageData) { - mergedMessages[mergedMessages.length - 1].content = [ - { type: 'text', text: mergedMessages[mergedMessages.length - 1].content[0]?.text || mergedMessages[mergedMessages.length - 1].content }, - { - type: 'image', source: { - type: 'base64', - media_type: mimeType, - data: base64Data, - }, - }, - ]; - } }); return { messages: mergedMessages, systemPrompt: systemPrompt.trim() }; From 4a2989718caf6f8a64515035c01a43f26b38cbe3 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 10:34:17 +0000 Subject: [PATCH 21/50] ESLint and JSDoc fixes --- public/script.js | 2 ++ public/scripts/tool-calling.js | 42 ++++++++++++++++++++-------------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/public/script.js b/public/script.js index 22ce72641..109c6cc36 100644 --- a/public/script.js +++ b/public/script.js @@ -8180,6 +8180,8 @@ window['SillyTavern'].getContext = function () { unregisterMacro: MacrosParser.unregisterMacro.bind(MacrosParser), registerFunctionTool: ToolManager.registerFunctionTool.bind(ToolManager), unregisterFunctionTool: ToolManager.unregisterFunctionTool.bind(ToolManager), + isToolCallingSupported: ToolManager.isToolCallingSupported.bind(ToolManager), + canPerformToolCalls: ToolManager.canPerformToolCalls.bind(ToolManager), registerDebugFunction: registerDebugFunction, /** @deprecated Use renderExtensionTemplateAsync instead. */ renderExtensionTemplate: renderExtensionTemplate, diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index 15d3db954..7457392b3 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -18,6 +18,16 @@ import { Popup } from './popup.js'; * @property {Error[]} errors Errors that occurred during tool invocation */ +/** + * @typedef {object} ToolRegistration + * @property {string} name - The name of the tool. + * @property {string} displayName - The display name of the tool. + * @property {string} description - A description of the tool. + * @property {object} parameters - The parameters for the tool. + * @property {function} action - The action to perform when the tool is invoked. + * @property {function} formatMessage - A function to format the tool call message. + */ + /** * A class that represents a tool definition. */ @@ -136,13 +146,7 @@ export class ToolManager { /** * Registers a new tool with the tool registry. - * @param {object} tool The tool to register. - * @param {string} tool.name The name of the tool. - * @param {string} tool.displayName A user-friendly display name for the tool. - * @param {string} tool.description A description of what the tool does. - * @param {object} tool.parameters A JSON schema for the parameters that the tool accepts. - * @param {function} tool.action A function that will be called when the tool is executed. - * @param {function} tool.formatMessage A function that will be called to format the tool call toast. + * @param {ToolRegistration} tool The tool to register. */ static registerFunctionTool({ name, displayName, description, parameters, action, formatMessage }) { // Convert WIP arguments @@ -201,6 +205,12 @@ export class ToolManager { } } + /** + * Formats a message for a tool call by name. + * @param {string} name The name of the tool to format the message for. + * @param {object} parameters Function tool call parameters. + * @returns {string} The formatted message for the tool call. + */ static formatToolCallMessage(name, parameters) { if (!this.#tools.has(name)) { return `Invoked unknown tool: ${name}`; @@ -295,9 +305,14 @@ export class ToolManager { } } + /** + * Apply a tool call delta to a target object. + * @param {object} target The target object to apply the delta to + * @param {object} delta The delta object to apply + */ static #applyToolCallDelta(target, delta) { for (const key in delta) { - if (!delta.hasOwnProperty(key)) continue; + if (!Object.prototype.hasOwnProperty.call(delta, key)) continue; if (key === '__proto__' || key === 'constructor') continue; const deltaValue = delta[key]; @@ -409,13 +424,6 @@ export class ToolManager { errors: [], }; const toolCalls = ToolManager.#getToolCallsFromData(data); - const oaiCompatibleSources = [ - chat_completion_sources.OPENAI, - chat_completion_sources.CUSTOM, - chat_completion_sources.MISTRALAI, - chat_completion_sources.OPENROUTER, - chat_completion_sources.GROQ, - ]; if (!Array.isArray(toolCalls)) { return result; @@ -463,7 +471,7 @@ export class ToolManager { * @param {ToolInvocation[]} invocations Tool invocations. * @returns {string} Formatted message with tool invocations. */ - static #formatMessage(invocations) { + static #formatToolInvocationMessage(invocations) { const tryParse = (x) => { try { return JSON.parse(x); } catch { return x; } }; const data = structuredClone(invocations); const detailsElement = document.createElement('details'); @@ -493,7 +501,7 @@ export class ToolManager { force_avatar: system_avatar, is_system: true, is_user: false, - mes: ToolManager.#formatMessage(invocations), + mes: ToolManager.#formatToolInvocationMessage(invocations), extra: { isSmallSys: true, tool_invocations: invocations, From 0cab91e0f8c0074d4bfdf33c3971b3dde1029412 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 13:39:08 +0300 Subject: [PATCH 22/50] Add Claude streamed tool parser --- public/scripts/tool-calling.js | 69 +++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index 15d3db954..57e21b618 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -256,41 +256,58 @@ export class ToolManager { * @returns {void} */ static parseToolCalls(toolCalls, parsed) { - if (!Array.isArray(parsed?.choices)) { - return; - } - for (const choice of parsed.choices) { - const choiceIndex = (typeof choice.index === 'number') ? choice.index : null; - const choiceDelta = choice.delta; + if (Array.isArray(parsed?.choices)) { + for (const choice of parsed.choices) { + const choiceIndex = (typeof choice.index === 'number') ? choice.index : null; + const choiceDelta = choice.delta; - if (choiceIndex === null || !choiceDelta) { - continue; - } - - const toolCallDeltas = choiceDelta?.tool_calls; - - if (!Array.isArray(toolCallDeltas)) { - continue; - } - - if (!Array.isArray(toolCalls[choiceIndex])) { - toolCalls[choiceIndex] = []; - } - - for (const toolCallDelta of toolCallDeltas) { - const toolCallIndex = (typeof toolCallDelta?.index === 'number') ? toolCallDelta.index : toolCallDeltas.indexOf(toolCallDelta); - - if (isNaN(toolCallIndex) || toolCallIndex < 0) { + if (choiceIndex === null || !choiceDelta) { continue; } + const toolCallDeltas = choiceDelta?.tool_calls; + + if (!Array.isArray(toolCallDeltas)) { + continue; + } + + if (!Array.isArray(toolCalls[choiceIndex])) { + toolCalls[choiceIndex] = []; + } + + for (const toolCallDelta of toolCallDeltas) { + const toolCallIndex = (typeof toolCallDelta?.index === 'number') ? toolCallDelta.index : toolCallDeltas.indexOf(toolCallDelta); + + if (isNaN(toolCallIndex) || toolCallIndex < 0) { + continue; + } + + if (toolCalls[choiceIndex][toolCallIndex] === undefined) { + toolCalls[choiceIndex][toolCallIndex] = {}; + } + + const targetToolCall = toolCalls[choiceIndex][toolCallIndex]; + + ToolManager.#applyToolCallDelta(targetToolCall, toolCallDelta); + } + } + } + if (typeof parsed?.content_block === 'object') { + const choiceIndex = 0; + + if (parsed?.content_block?.type === 'tool_use') { + if (!Array.isArray(toolCalls[choiceIndex])) { + toolCalls[choiceIndex] = []; + } + + const toolCallIndex = toolCalls[choiceIndex].length; + if (toolCalls[choiceIndex][toolCallIndex] === undefined) { toolCalls[choiceIndex][toolCallIndex] = {}; } const targetToolCall = toolCalls[choiceIndex][toolCallIndex]; - - ToolManager.#applyToolCallDelta(targetToolCall, toolCallDelta); + Object.assign(targetToolCall, parsed.content_block); } } } From db04fff3df342e7c035c7dccccb88df5dd2fc75c Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 11:31:15 +0000 Subject: [PATCH 23/50] Claude: Streamed tool calls parser --- public/scripts/tool-calling.js | 63 ++++++++++++++++++++++++---------- 1 file changed, 45 insertions(+), 18 deletions(-) diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index cd989cc0b..264649753 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -136,6 +136,8 @@ export class ToolManager { */ static #tools = new Map(); + static #INPUT_DELTA_KEY = '__input_json_delta'; + /** * Returns an Array of all tools that have been registered. * @type {ToolDefinition[]} @@ -304,20 +306,48 @@ export class ToolManager { } if (typeof parsed?.content_block === 'object') { const choiceIndex = 0; + const toolCallIndex = parsed?.index ?? 0; if (parsed?.content_block?.type === 'tool_use') { if (!Array.isArray(toolCalls[choiceIndex])) { toolCalls[choiceIndex] = []; } - - const toolCallIndex = toolCalls[choiceIndex].length; - if (toolCalls[choiceIndex][toolCallIndex] === undefined) { toolCalls[choiceIndex][toolCallIndex] = {}; } - const targetToolCall = toolCalls[choiceIndex][toolCallIndex]; - Object.assign(targetToolCall, parsed.content_block); + ToolManager.#applyToolCallDelta(targetToolCall, parsed.content_block); + } + } + if (typeof parsed?.delta === 'object') { + const choiceIndex = 0; + const toolCallIndex = parsed?.index ?? 0; + const targetToolCall = toolCalls[choiceIndex]?.[toolCallIndex]; + if (targetToolCall){ + if (parsed?.delta?.type === 'input_json_delta') { + const jsonDelta = parsed?.delta?.partial_json; + if (!targetToolCall[this.#INPUT_DELTA_KEY]) { + targetToolCall[this.#INPUT_DELTA_KEY] = ''; + } + targetToolCall[this.#INPUT_DELTA_KEY] += jsonDelta; + } + } + } + if (parsed?.type === 'content_block_stop') { + const choiceIndex = 0; + const toolCallIndex = parsed?.index ?? 0; + const targetToolCall = toolCalls[choiceIndex]?.[toolCallIndex]; + if (targetToolCall) { + const jsonDeltaString = targetToolCall[this.#INPUT_DELTA_KEY]; + if (jsonDeltaString) { + try { + const jsonDelta = { input: JSON.parse(jsonDeltaString) }; + delete targetToolCall[this.#INPUT_DELTA_KEY]; + ToolManager.#applyToolCallDelta(targetToolCall, jsonDelta); + } catch (error) { + console.warn('Failed to apply input JSON delta:', error); + } + } } } } @@ -397,9 +427,12 @@ export class ToolManager { * @returns {any[]} Tool calls from the response data */ static #getToolCallsFromData(data) { + const isClaudeToolCall = c => Array.isArray(c) ? c.filter(x => x).every(isClaudeToolCall) : c?.input && c?.name && c?.id; + const convertClaudeToolCall = c => ({ id: c.id, function: { name: c.name, arguments: c.input } }); + // Parsed tool calls from streaming data - if (Array.isArray(data) && data.length > 0) { - return data[0]; + if (Array.isArray(data) && data.length > 0 && Array.isArray(data[0])) { + return isClaudeToolCall(data[0]) ? data[0].filter(x => x).map(convertClaudeToolCall) : data[0]; } // Parsed tool calls from non-streaming data @@ -412,19 +445,13 @@ export class ToolManager { } } + // Claude tool calls to OpenAI tool calls if (Array.isArray(data?.content)) { - // Claude tool calls to OpenAI tool calls - const content = data.content.filter(c => c.type === 'tool_use').map(c => { - return { - id: c.id, - function: { - name: c.name, - arguments: c.input, - }, - }; - }); + const content = data.content.filter(c => c.type === 'tool_use').map(convertClaudeToolCall); - return content; + if (content) { + return content; + } } } From 8c095f204a014c5e5f7f784b1f203c6b5089b1d7 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 12:00:20 +0000 Subject: [PATCH 24/50] Fix error on streaming if the processor was already destroyed --- public/script.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/public/script.js b/public/script.js index 109c6cc36..f34d57cf8 100644 --- a/public/script.js +++ b/public/script.js @@ -4420,7 +4420,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro const invocationResult = await ToolManager.invokeFunctionTools(streamingProcessor.toolCalls); if (invocationResult.hadToolCalls) { const lastMessage = chat[chat.length - 1]; - const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); + const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor?.result); shouldDeleteMessage && await deleteLastMessage(); if (!invocationResult.invocations.length && shouldDeleteMessage) { ToolManager.showToolCallError(invocationResult.errors); From cc3cc58a06a7ecfd841e4ddad8f7190c61da5d98 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 12:24:10 +0000 Subject: [PATCH 25/50] Claude: fix token counting when tool_calls are used --- src/prompt-converters.js | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/prompt-converters.js b/src/prompt-converters.js index b4b3d718b..ceb3c1189 100644 --- a/src/prompt-converters.js +++ b/src/prompt-converters.js @@ -19,6 +19,14 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill, //Prepare messages for claude. //When 'Exclude Human/Assistant prefixes' checked, setting messages role to the 'system'(last message is exception). if (messages.length > 0) { + messages.forEach((m) => { + if (!m.content) { + m.content = ''; + } + if (m.tool_calls) { + m.content += JSON.stringify(m.tool_calls); + } + }); if (excludePrefixes) { messages.slice(0, -1).forEach(message => message.role = 'system'); } else { From c853547b11a8d219fb9c3c04bd05737dc56678e7 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Fri, 4 Oct 2024 13:04:19 +0000 Subject: [PATCH 26/50] Add a function tool for image generation --- .../extensions/stable-diffusion/index.js | 54 +++++++++++++++++++ .../extensions/stable-diffusion/settings.html | 4 ++ 2 files changed, 58 insertions(+) diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index e96575d56..0289ff807 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -32,6 +32,7 @@ import { debounce_timeout } from '../../constants.js'; import { SlashCommandEnumValue } from '../../slash-commands/SlashCommandEnumValue.js'; import { POPUP_RESULT, POPUP_TYPE, Popup, callGenericPopup } from '../../popup.js'; import { commonEnumProviders } from '../../slash-commands/SlashCommandCommonEnumsProvider.js'; +import { ToolManager } from '../../tool-calling.js'; export { MODULE_NAME }; const MODULE_NAME = 'sd'; @@ -62,6 +63,7 @@ const initiators = { interactive: 'interactive', wand: 'wand', swipe: 'swipe', + tool: 'tool', }; const generationMode = { @@ -226,6 +228,7 @@ const defaultSettings = { multimodal_captioning: false, snap: false, free_extend: false, + function_tool: false, prompts: promptTemplates, @@ -291,6 +294,10 @@ const defaultSettings = { const writePromptFieldsDebounced = debounce(writePromptFields, debounce_timeout.relaxed); function processTriggers(chat, _, abort) { + if (extension_settings.sd.function_tool && ToolManager.isToolCallingSupported()) { + return; + } + if (!extension_settings.sd.interactive_mode) { return; } @@ -447,6 +454,7 @@ async function loadSettings() { $('#sd_interactive_visible').prop('checked', extension_settings.sd.interactive_visible); $('#sd_stability_style_preset').val(extension_settings.sd.stability_style_preset); $('#sd_huggingface_model_id').val(extension_settings.sd.huggingface_model_id); + $('#sd_function_tool').prop('checked', extension_settings.sd.function_tool); for (const style of extension_settings.sd.styles) { const option = document.createElement('option'); @@ -461,6 +469,7 @@ async function loadSettings() { toggleSourceControls(); addPromptTemplates(); + registerFunctionTool(); await loadSettingOptions(); } @@ -910,6 +919,12 @@ async function onSourceChange() { await loadSettingOptions(); } +function onFunctionToolInput() { + extension_settings.sd.function_tool = !!$(this).prop('checked'); + saveSettingsDebounced(); + registerFunctionTool(); +} + async function onOpenAiStyleSelect() { extension_settings.sd.openai_style = String($('#sd_openai_style').find(':selected').val()); saveSettingsDebounced(); @@ -3822,6 +3837,44 @@ function applyCommandArguments(args) { return currentSettings; } +function registerFunctionTool() { + if (!extension_settings.sd.function_tool) { + return ToolManager.unregisterFunctionTool('GenerateImage'); + } + + ToolManager.registerFunctionTool({ + name: 'GenerateImage', + displayName: 'Generate Image', + description: [ + 'Generate an image from a given text prompt.', + 'Use when a user asks for an image, a selfie, to picture a scene, etc.', + ].join(' '), + parameters: Object.freeze({ + $schema: 'http://json-schema.org/draft-04/schema#', + type: 'object', + properties: { + prompt: { + type: 'string', + description: [ + 'The text prompt used to generate the image.', + 'Must represent an exhaustive description of the desired image that will allow an artist or a photographer to perfectly recreate it.', + ], + }, + }, + required: [ + 'prompt', + ], + }), + action: async (args) => { + if (!isValidState()) throw new Error('Image generation is not configured.'); + if (!args) throw new Error('Missing arguments'); + if (!args.prompt) throw new Error('Missing prompt'); + return generatePicture(initiators.tool, {}, args.prompt); + }, + formatMessage: () => 'Generating an image...', + }); +} + jQuery(async () => { await addSDGenButtons(); @@ -4175,6 +4228,7 @@ jQuery(async () => { $('#sd_stability_key').on('click', onStabilityKeyClick); $('#sd_stability_style_preset').on('change', onStabilityStylePresetChange); $('#sd_huggingface_model_id').on('input', onHFModelInput); + $('#sd_function_tool').on('input', onFunctionToolInput); if (!CSS.supports('field-sizing', 'content')) { $('.sd_settings .inline-drawer-toggle').on('click', function () { diff --git a/public/scripts/extensions/stable-diffusion/settings.html b/public/scripts/extensions/stable-diffusion/settings.html index 2efa74d83..6870facd2 100644 --- a/public/scripts/extensions/stable-diffusion/settings.html +++ b/public/scripts/extensions/stable-diffusion/settings.html @@ -18,6 +18,10 @@ Interactive mode +