diff --git a/public/global.d.ts b/public/global.d.ts index c8bfe14c2..39ba4ea0c 100644 --- a/public/global.d.ts +++ b/public/global.d.ts @@ -1365,44 +1365,3 @@ declare namespace moment { declare global { const moment: typeof moment; } - -/** - * Callback data for the `LLM_FUNCTION_TOOL_REGISTER` event type that is triggered when a function tool can be registered. - */ -interface FunctionToolRegister { - /** - * The type of generation that is being used - */ - type?: string; - /** - * Generation data, including messages and sampling parameters - */ - data: Record; - /** - * Callback to register an LLM function tool. - */ - registerFunctionTool: typeof registerFunctionTool; -} - -/** - * Callback data for the `LLM_FUNCTION_TOOL_REGISTER` event type that is triggered when a function tool is registered. - * @param name Name of the function tool to register - * @param description Description of the function tool - * @param params JSON schema for the parameters of the function tool - * @param required Whether the function tool should be forced to be used - */ -declare function registerFunctionTool(name: string, description: string, params: object, required: boolean): Promise; - -/** - * Callback data for the `LLM_FUNCTION_TOOL_CALL` event type that is triggered when a function tool is called. - */ -interface FunctionToolCall { - /** - * Name of the function tool to call - */ - name: string; - /** - * JSON object with the parameters to pass to the function tool - */ - arguments: string; -} diff --git a/public/script.js b/public/script.js index 35bfcbab0..fac11b8f4 100644 --- a/public/script.js +++ b/public/script.js @@ -246,6 +246,7 @@ import { initInputMarkdown } from './scripts/input-md-formatting.js'; import { AbortReason } from './scripts/util/AbortReason.js'; import { initSystemPrompts } from './scripts/sysprompt.js'; import { registerExtensionSlashCommands as initExtensionSlashCommands } from './scripts/extensions-slashcommands.js'; +import { ToolManager } from './scripts/tool-calling.js'; //exporting functions and vars for mods export { @@ -463,8 +464,6 @@ export const event_types = { FILE_ATTACHMENT_DELETED: 'file_attachment_deleted', WORLDINFO_FORCE_ACTIVATE: 'worldinfo_force_activate', OPEN_CHARACTER_LIBRARY: 'open_character_library', - LLM_FUNCTION_TOOL_REGISTER: 'llm_function_tool_register', - LLM_FUNCTION_TOOL_CALL: 'llm_function_tool_call', ONLINE_STATUS_CHANGED: 'online_status_changed', IMAGE_SWIPED: 'image_swiped', CONNECTION_PROFILE_LOADED: 'connection_profile_loaded', @@ -2921,6 +2920,7 @@ class StreamingProcessor { this.swipes = []; /** @type {import('./scripts/logprobs.js').TokenLogprobs[]} */ this.messageLogprobs = []; + this.toolCalls = []; } #checkDomElements(messageId) { @@ -3139,7 +3139,7 @@ class StreamingProcessor { } /** - * @returns {Generator<{ text: string, swipes: string[], logprobs: import('./scripts/logprobs.js').TokenLogprobs }, void, void>} + * @returns {Generator<{ text: string, swipes: string[], logprobs: import('./scripts/logprobs.js').TokenLogprobs, toolCalls: any[] }, void, void>} */ *nullStreamingGeneration() { throw new Error('Generation function for streaming is not hooked up'); @@ -3161,12 +3161,13 @@ class StreamingProcessor { try { const sw = new Stopwatch(1000 / power_user.streaming_fps); const timestamps = []; - for await (const { text, swipes, logprobs } of this.generator()) { + for await (const { text, swipes, logprobs, toolCalls } of this.generator()) { timestamps.push(Date.now()); if (this.isStopped) { return; } + this.toolCalls = toolCalls; this.result = text; this.swipes = Array.from(swipes ?? []); if (logprobs) { @@ -4405,6 +4406,20 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro getMessage = continue_mag + getMessage; } + if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { + const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls); + if (invocations.length) { + const lastMessage = chat[chat.length - 1]; + const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); + if (shouldDeleteMessage) { + await deleteLastMessage(); + streamingProcessor = null; + } + ToolManager.saveFunctionToolInvocations(invocations); + return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); + } + } + if (streamingProcessor && !streamingProcessor.isStopped && streamingProcessor.isFinished) { await streamingProcessor.onFinishStreaming(streamingProcessor.messageId, getMessage); streamingProcessor = null; @@ -4440,6 +4455,14 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro throw new Error(data?.response); } + if (ToolManager.isFunctionCallingSupported()) { + const invocations = await ToolManager.checkFunctionToolCalls(data); + if (invocations.length) { + ToolManager.saveFunctionToolInvocations(invocations); + return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); + } + } + //const getData = await response.json(); let getMessage = extractMessageFromData(data); let title = extractTitleFromData(data); @@ -7853,7 +7876,7 @@ function openAlternateGreetings() { if (menu_type !== 'create') { await createOrEditCharacter(); } - } + }, }); for (let index = 0; index < getArray().length; index++) { @@ -8130,6 +8153,8 @@ window['SillyTavern'].getContext = function () { registerHelper: () => { }, registerMacro: MacrosParser.registerMacro.bind(MacrosParser), unregisterMacro: MacrosParser.unregisterMacro.bind(MacrosParser), + registerFunctionTool: ToolManager.registerFunctionTool.bind(ToolManager), + unregisterFunctionTool: ToolManager.unregisterFunctionTool.bind(ToolManager), registerDebugFunction: registerDebugFunction, /** @deprecated Use renderExtensionTemplateAsync instead. */ renderExtensionTemplate: renderExtensionTemplate, diff --git a/public/scripts/extensions/expressions/index.js b/public/scripts/extensions/expressions/index.js index 4a09e3aeb..078375cf1 100644 --- a/public/scripts/extensions/expressions/index.js +++ b/public/scripts/extensions/expressions/index.js @@ -9,7 +9,6 @@ import { debounce_timeout } from '../../constants.js'; import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js'; import { SlashCommand } from '../../slash-commands/SlashCommand.js'; import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js'; -import { isFunctionCallingSupported } from '../../openai.js'; import { SlashCommandEnumValue, enumTypes } from '../../slash-commands/SlashCommandEnumValue.js'; import { commonEnumProviders } from '../../slash-commands/SlashCommandCommonEnumsProvider.js'; import { slashCommandReturnHelper } from '../../slash-commands/SlashCommandReturnHelper.js'; @@ -21,7 +20,6 @@ const UPDATE_INTERVAL = 2000; const STREAMING_UPDATE_INTERVAL = 10000; const TALKINGCHECK_UPDATE_INTERVAL = 500; const DEFAULT_FALLBACK_EXPRESSION = 'joy'; -const FUNCTION_NAME = 'set_emotion'; const DEFAULT_LLM_PROMPT = 'Ignore previous instructions. Classify the emotion of the last message. Output just one word, e.g. "joy" or "anger". Choose only one of the following labels: {{labels}}'; const DEFAULT_EXPRESSIONS = [ 'talkinghead', @@ -1017,10 +1015,6 @@ async function getLlmPrompt(labels) { return ''; } - if (isFunctionCallingSupported()) { - return ''; - } - const labelsString = labels.map(x => `"${x}"`).join(', '); const prompt = substituteParamsExtended(String(extension_settings.expressions.llmPrompt), { labels: labelsString }); return prompt; @@ -1056,41 +1050,6 @@ function parseLlmResponse(emotionResponse, labels) { throw new Error('Could not parse emotion response ' + emotionResponse); } -/** - * Registers the function tool for the LLM API. - * @param {FunctionToolRegister} args Function tool register arguments. - */ -function onFunctionToolRegister(args) { - if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isFunctionCallingSupported()) { - // Only trigger on quiet mode - if (args.type !== 'quiet') { - return; - } - - const emotions = DEFAULT_EXPRESSIONS.filter((e) => e != 'talkinghead'); - const jsonSchema = { - $schema: 'http://json-schema.org/draft-04/schema#', - type: 'object', - properties: { - emotion: { - type: 'string', - enum: emotions, - description: `One of the following: ${JSON.stringify(emotions)}`, - }, - }, - required: [ - 'emotion', - ], - }; - args.registerFunctionTool( - FUNCTION_NAME, - substituteParams('Sets the label that best describes the current emotional state of {{char}}. Only select one of the enumerated values.'), - jsonSchema, - true, - ); - } -} - function onTextGenSettingsReady(args) { // Only call if inside an API call if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isJsonSchemaSupported()) { @@ -1164,18 +1123,9 @@ export async function getExpressionLabel(text, expressionsApi = extension_settin const expressionsList = await getExpressionsList(); const prompt = substituteParamsExtended(customPrompt, { labels: expressionsList }) || await getLlmPrompt(expressionsList); - let functionResult = null; eventSource.once(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady); - eventSource.once(event_types.LLM_FUNCTION_TOOL_REGISTER, onFunctionToolRegister); - eventSource.once(event_types.LLM_FUNCTION_TOOL_CALL, (/** @type {FunctionToolCall} */ args) => { - if (args.name !== FUNCTION_NAME) { - return; - } - - functionResult = args?.arguments; - }); const emotionResponse = await generateRaw(text, main_api, false, false, prompt); - return parseLlmResponse(functionResult || emotionResponse, expressionsList); + return parseLlmResponse(emotionResponse, expressionsList); } // Extras default: { diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 20c141316..a6048dce5 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -70,6 +70,7 @@ import { renderTemplateAsync } from './templates.js'; import { SlashCommandEnumValue } from './slash-commands/SlashCommandEnumValue.js'; import { Popup, POPUP_RESULT } from './popup.js'; import { t } from './i18n.js'; +import { ToolManager } from './tool-calling.js'; export { openai_messages_count, @@ -1863,8 +1864,8 @@ async function sendOpenAIRequest(type, messages, signal) { generate_data['seed'] = oai_settings.seed; } - if (isFunctionCallingSupported() && !stream) { - await registerFunctionTools(type, generate_data); + if (!canMultiSwipe && ToolManager.isFunctionCallingSupported()) { + await ToolManager.registerFunctionToolsOpenAI(generate_data); } if (isOAI && oai_settings.openai_model.startsWith('o1-')) { @@ -1911,6 +1912,7 @@ async function sendOpenAIRequest(type, messages, signal) { return async function* streamData() { let text = ''; const swipes = []; + const toolCalls = []; while (true) { const { done, value } = await reader.read(); if (done) return; @@ -1926,7 +1928,9 @@ async function sendOpenAIRequest(type, messages, signal) { text += getStreamingReply(parsed); } - yield { text, swipes: swipes, logprobs: parseChatCompletionLogprobs(parsed) }; + ToolManager.parseToolCalls(toolCalls, parsed); + + yield { text, swipes: swipes, logprobs: parseChatCompletionLogprobs(parsed), toolCalls: toolCalls }; } }; } @@ -1948,147 +1952,10 @@ async function sendOpenAIRequest(type, messages, signal) { delay(1).then(() => saveLogprobsForActiveMessage(logprobs, null)); } - if (isFunctionCallingSupported()) { - await checkFunctionToolCalls(data); - } - return data; } } -/** - * Register function tools for the next chat completion request. - * @param {string} type Generation type - * @param {object} data Generation data - */ -async function registerFunctionTools(type, data) { - let toolChoice = 'auto'; - const tools = []; - - /** - * @type {registerFunctionTool} - */ - const registerFunctionTool = (name, description, parameters, required) => { - tools.push({ - type: 'function', - function: { - name, - description, - parameters, - }, - }); - - if (required) { - toolChoice = 'required'; - } - }; - - /** - * @type {FunctionToolRegister} - */ - const args = { - type, - data, - registerFunctionTool, - }; - - await eventSource.emit(event_types.LLM_FUNCTION_TOOL_REGISTER, args); - - if (tools.length) { - console.log('Registered function tools:', tools); - - data['tools'] = tools; - data['tool_choice'] = toolChoice; - } -} - -async function checkFunctionToolCalls(data) { - const oaiCompat = [ - chat_completion_sources.OPENAI, - chat_completion_sources.CUSTOM, - chat_completion_sources.MISTRALAI, - chat_completion_sources.OPENROUTER, - chat_completion_sources.GROQ, - ]; - if (oaiCompat.includes(oai_settings.chat_completion_source)) { - if (!Array.isArray(data?.choices)) { - return; - } - - // Find a choice with 0-index - const choice = data.choices.find(choice => choice.index === 0); - - if (!choice) { - return; - } - - const toolCalls = choice.message.tool_calls; - - if (!Array.isArray(toolCalls)) { - return; - } - - for (const toolCall of toolCalls) { - if (typeof toolCall.function !== 'object') { - continue; - } - - /** @type {FunctionToolCall} */ - const args = toolCall.function; - console.log('Function tool call:', toolCall); - await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args); - } - } - - if ([chat_completion_sources.CLAUDE].includes(oai_settings.chat_completion_source)) { - if (!Array.isArray(data?.content)) { - return; - } - - for (const content of data.content) { - if (content.type === 'tool_use') { - /** @type {FunctionToolCall} */ - const args = { name: content.name, arguments: JSON.stringify(content.input) }; - await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args); - } - } - } - - if ([chat_completion_sources.COHERE].includes(oai_settings.chat_completion_source)) { - if (!Array.isArray(data?.tool_calls)) { - return; - } - - for (const toolCall of data.tool_calls) { - /** @type {FunctionToolCall} */ - const args = { name: toolCall.name, arguments: JSON.stringify(toolCall.parameters) }; - console.log('Function tool call:', toolCall); - await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args); - } - } -} - -export function isFunctionCallingSupported() { - if (main_api !== 'openai') { - return false; - } - - if (!oai_settings.function_calling) { - return false; - } - - const supportedSources = [ - chat_completion_sources.OPENAI, - chat_completion_sources.COHERE, - chat_completion_sources.CUSTOM, - chat_completion_sources.MISTRALAI, - chat_completion_sources.CLAUDE, - chat_completion_sources.OPENROUTER, - chat_completion_sources.GROQ, - ]; - return supportedSources.includes(oai_settings.chat_completion_source); -} - function getStreamingReply(data) { if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) { return data?.delta?.text || ''; @@ -4019,7 +3886,7 @@ async function onModelChange() { $('#openai_max_context').attr('max', max_32k); } else if (value === 'text-bison-001') { $('#openai_max_context').attr('max', max_8k); - // The ultra endpoints are possibly dead: + // The ultra endpoints are possibly dead: } else if (value.includes('gemini-1.0-ultra') || value === 'gemini-ultra') { $('#openai_max_context').attr('max', max_32k); } else { diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js new file mode 100644 index 000000000..c2f6c209f --- /dev/null +++ b/public/scripts/tool-calling.js @@ -0,0 +1,381 @@ +import { chat, main_api } from '../script.js'; +import { chat_completion_sources, oai_settings } from './openai.js'; + +/** + * @typedef {object} ToolInvocation + * @property {string} id - A unique identifier for the tool invocation. + * @property {string} name - The name of the tool. + * @property {string} parameters - The parameters for the tool invocation. + * @property {string} result - The result of the tool invocation. + */ + +/** + * A class that represents a tool definition. + */ +class ToolDefinition { + /** + * A unique name for the tool. + * @type {string} + */ + #name; + + /** + * A description of what the tool does. + * @type {string} + */ + #description; + + /** + * A JSON schema for the parameters that the tool accepts. + * @type {object} + */ + #parameters; + + /** + * A function that will be called when the tool is executed. + * @type {function} + */ + #action; + + /** + * Creates a new ToolDefinition. + * @param {string} name A unique name for the tool. + * @param {string} description A description of what the tool does. + * @param {object} parameters A JSON schema for the parameters that the tool accepts. + * @param {function} action A function that will be called when the tool is executed. + */ + constructor(name, description, parameters, action) { + this.#name = name; + this.#description = description; + this.#parameters = parameters; + this.#action = action; + } + + /** + * Converts the ToolDefinition to an OpenAI API representation + * @returns {object} OpenAI API representation of the tool. + */ + toFunctionOpenAI() { + return { + type: 'function', + function: { + name: this.#name, + description: this.#description, + parameters: this.#parameters, + }, + }; + } + + /** + * Invokes the tool with the given parameters. + * @param {object} parameters The parameters to pass to the tool. + * @returns {Promise} The result of the tool's action function. + */ + async invoke(parameters) { + return await this.#action(parameters); + } +} + +/** + * A class that manages the registration and invocation of tools. + */ +export class ToolManager { + /** + * A map of tool names to tool definitions. + * @type {Map} + */ + static #tools = new Map(); + + /** + * Returns an Array of all tools that have been registered. + * @type {ToolDefinition[]} + */ + static get tools() { + return Array.from(this.#tools.values()); + } + + /** + * Registers a new tool with the tool registry. + * @param {string} name The name of the tool. + * @param {string} description A description of what the tool does. + * @param {object} parameters A JSON schema for the parameters that the tool accepts. + * @param {function} action A function that will be called when the tool is executed. + */ + static registerFunctionTool(name, description, parameters, action) { + if (this.#tools.has(name)) { + console.warn(`A tool with the name "${name}" has already been registered. The definition will be overwritten.`); + } + + const definition = new ToolDefinition(name, description, parameters, action); + this.#tools.set(name, definition); + } + + /** + * Removes a tool from the tool registry. + * @param {string} name The name of the tool to unregister. + */ + static unregisterFunctionTool(name) { + if (!this.#tools.has(name)) { + console.warn(`No tool with the name "${name}" has been registered.`); + return; + } + + this.#tools.delete(name); + } + + /** + * Invokes a tool by name. Returns the result of the tool's action function. + * @param {string} name The name of the tool to invoke. + * @param {object} parameters Function parameters. For example, if the tool requires a "name" parameter, you would pass {name: "value"}. + * @returns {Promise} The result of the tool's action function. If an error occurs, null is returned. Non-string results are JSON-stringified. + */ + static async invokeFunctionTool(name, parameters) { + try { + if (!this.#tools.has(name)) { + throw new Error(`No tool with the name "${name}" has been registered.`); + } + + const invokeParameters = typeof parameters === 'string' ? JSON.parse(parameters) : parameters; + const tool = this.#tools.get(name); + const result = await tool.invoke(invokeParameters); + return typeof result === 'string' ? result : JSON.stringify(result); + } catch (error) { + console.error(`An error occurred while invoking the tool "${name}":`, error); + return null; + } + } + + /** + * Register function tools for the next chat completion request. + * @param {object} data Generation data + */ + static async registerFunctionToolsOpenAI(data) { + const tools = []; + + for (const tool of ToolManager.tools) { + tools.push(tool.toFunctionOpenAI()); + } + + if (tools.length) { + console.log('Registered function tools:', tools); + + data['tools'] = tools; + data['tool_choice'] = 'auto'; + } + } + + /** + * Utility function to parse tool calls from a parsed response. + * @param {any[]} toolCalls The tool calls to update. + * @param {any} parsed The parsed response from the OpenAI API. + * @returns {void} + */ + static parseToolCalls(toolCalls, parsed) { + if (!Array.isArray(parsed?.choices)) { + return; + } + for (const choice of parsed.choices) { + const choiceIndex = (typeof choice.index === 'number') ? choice.index : null; + const choiceDelta = choice.delta; + + if (choiceIndex === null || !choiceDelta) { + continue; + } + + const toolCallDeltas = choiceDelta?.tool_calls; + + if (!Array.isArray(toolCallDeltas)) { + continue; + } + + if (!Array.isArray(toolCalls[choiceIndex])) { + toolCalls[choiceIndex] = []; + } + + for (const toolCallDelta of toolCallDeltas) { + const toolCallIndex = (typeof toolCallDelta?.index === 'number') ? toolCallDelta.index : null; + + if (toolCallIndex === null) { + continue; + } + + if (toolCalls[choiceIndex][toolCallIndex] === undefined) { + toolCalls[choiceIndex][toolCallIndex] = {}; + } + + const targetToolCall = toolCalls[choiceIndex][toolCallIndex]; + + ToolManager.#applyToolCallDelta(targetToolCall, toolCallDelta); + } + } + } + + static #applyToolCallDelta(target, delta) { + for (const key in delta) { + if (!delta.hasOwnProperty(key)) continue; + + const deltaValue = delta[key]; + const targetValue = target[key]; + + if (deltaValue === null || deltaValue === undefined) { + target[key] = deltaValue; + continue; + } + + if (typeof deltaValue === 'string') { + if (typeof targetValue === 'string') { + // Concatenate strings + target[key] = targetValue + deltaValue; + } else { + target[key] = deltaValue; + } + } else if (typeof deltaValue === 'object' && !Array.isArray(deltaValue)) { + if (typeof targetValue !== 'object' || targetValue === null || Array.isArray(targetValue)) { + target[key] = {}; + } + // Recursively apply deltas to nested objects + ToolManager.#applyToolCallDelta(target[key], deltaValue); + } else { + // Assign other types directly + target[key] = deltaValue; + } + } + } + + static isFunctionCallingSupported() { + if (main_api !== 'openai') { + return false; + } + + if (!oai_settings.function_calling) { + return false; + } + + const supportedSources = [ + chat_completion_sources.OPENAI, + //chat_completion_sources.COHERE, + chat_completion_sources.CUSTOM, + chat_completion_sources.MISTRALAI, + //chat_completion_sources.CLAUDE, + chat_completion_sources.OPENROUTER, + chat_completion_sources.GROQ, + ]; + return supportedSources.includes(oai_settings.chat_completion_source); + } + + static #getToolCallsFromData(data) { + // Parsed tool calls from streaming data + if (Array.isArray(data) && data.length > 0) { + return data[0]; + } + + // Parsed tool calls from non-streaming data + if (!Array.isArray(data?.choices)) { + return; + } + + // Find a choice with 0-index + const choice = data.choices.find(choice => choice.index === 0); + + if (!choice) { + return; + } + + return choice.message.tool_calls; + } + + /** + * Check for function tool calls in the response data and invoke them. + * @param {any} data Reply data + * @returns {Promise} Successful tool invocations + */ + static async checkFunctionToolCalls(data) { + if (!ToolManager.isFunctionCallingSupported()) { + return []; + } + + /** @type {ToolInvocation[]} */ + const invocations = []; + const toolCalls = ToolManager.#getToolCallsFromData(data); + const oaiCompat = [ + chat_completion_sources.OPENAI, + chat_completion_sources.CUSTOM, + chat_completion_sources.MISTRALAI, + chat_completion_sources.OPENROUTER, + chat_completion_sources.GROQ, + ]; + + if (oaiCompat.includes(oai_settings.chat_completion_source)) { + if (!Array.isArray(toolCalls)) { + return; + } + + for (const toolCall of toolCalls) { + if (typeof toolCall.function !== 'object') { + continue; + } + + console.log('Function tool call:', toolCall); + const id = toolCall.id; + const parameters = toolCall.function.arguments; + const name = toolCall.function.name; + + toastr.info('Invoking function tool: ' + name); + const result = await ToolManager.invokeFunctionTool(name, parameters); + toastr.info('Function tool result: ' + result); + + // Save a successful invocation + if (result) { + invocations.push({ id, name, result, parameters }); + } + } + } + + /* + if ([chat_completion_sources.CLAUDE].includes(oai_settings.chat_completion_source)) { + if (!Array.isArray(data?.content)) { + return; + } + + for (const content of data.content) { + if (content.type === 'tool_use') { + const args = { name: content.name, arguments: JSON.stringify(content.input) }; + } + } + } + */ + + /* + if ([chat_completion_sources.COHERE].includes(oai_settings.chat_completion_source)) { + if (!Array.isArray(data?.tool_calls)) { + return; + } + + for (const toolCall of data.tool_calls) { + const args = { name: toolCall.name, arguments: JSON.stringify(toolCall.parameters) }; + console.log('Function tool call:', toolCall); + } + } + */ + + return invocations; + } + + /** + * Saves function tool invocations to the last user chat message extra metadata. + * @param {ToolInvocation[]} invocations + */ + static saveFunctionToolInvocations(invocations) { + for (let index = chat.length - 1; index >= 0; index--) { + const message = chat[index]; + if (message.is_user) { + if (!message.extra || typeof message.extra !== 'object') { + message.extra = {}; + } + message.extra.tool_invocations = invocations; + debugger; + break; + } + } + } +} diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index ca86db051..095a65ceb 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -121,18 +121,20 @@ async function sendClaudeRequest(request, response) { ? [{ type: 'text', text: convertedPrompt.systemPrompt, cache_control: { type: 'ephemeral' } }] : convertedPrompt.systemPrompt; } + /* if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { // Claude doesn't do prefills on function calls, and doesn't allow empty messages if (convertedPrompt.messages.length && convertedPrompt.messages[convertedPrompt.messages.length - 1].role === 'assistant') { convertedPrompt.messages.push({ role: 'user', content: '.' }); } additionalHeaders['anthropic-beta'] = 'tools-2024-05-16'; - requestBody.tool_choice = { type: request.body.tool_choice === 'required' ? 'any' : 'auto' }; + requestBody.tool_choice = { type: request.body.tool_choice }; requestBody.tools = request.body.tools .filter(tool => tool.type === 'function') .map(tool => tool.function) .map(fn => ({ name: fn.name, description: fn.description, input_schema: fn.parameters })); } + */ if (enableSystemPromptCache) { additionalHeaders['anthropic-beta'] = 'prompt-caching-2024-07-31'; } @@ -479,7 +481,7 @@ async function sendMistralAIRequest(request, response) { if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { requestBody['tools'] = request.body.tools; - requestBody['tool_choice'] = request.body.tool_choice === 'required' ? 'any' : 'auto'; + requestBody['tool_choice'] = request.body.tool_choice; } const config = { @@ -549,11 +551,13 @@ async function sendCohereRequest(request, response) { }); } + /* if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { tools.push(...convertCohereTools(request.body.tools)); // Can't have both connectors and tools in the same request connectors.splice(0, connectors.length); } + */ // https://docs.cohere.com/reference/chat const requestBody = { @@ -910,18 +914,6 @@ router.post('/generate', jsonParser, function (request, response) { apiKey = readSecret(request.user.directories, SECRET_KEYS.GROQ); headers = {}; bodyParams = {}; - - // 'required' tool choice is not supported by Groq - if (request.body.tool_choice === 'required') { - if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { - request.body.tool_choice = request.body.tools.length > 1 - ? 'auto' : - { type: 'function', function: { name: request.body.tools[0]?.function?.name } }; - - } else { - request.body.tool_choice = 'none'; - } - } } else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.ZEROONEAI) { apiUrl = API_01AI; apiKey = readSecret(request.user.directories, SECRET_KEYS.ZEROONEAI); @@ -958,7 +950,7 @@ router.post('/generate', jsonParser, function (request, response) { controller.abort(); }); - if (!isTextCompletion) { + if (!isTextCompletion && Array.isArray(request.body.tools) && request.body.tools.length > 0) { bodyParams['tools'] = request.body.tools; bodyParams['tool_choice'] = request.body.tool_choice; }