mirror of
				https://github.com/SillyTavern/SillyTavern.git
				synced 2025-06-05 21:59:27 +02:00 
			
		
		
		
	Implement function tool calling for OpenAI
This commit is contained in:
		| @@ -4408,7 +4408,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro | |||||||
|  |  | ||||||
|             if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { |             if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { | ||||||
|                 const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls); |                 const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls); | ||||||
|                 if (invocations.length) { |                 if (Array.isArray(invocations) && invocations.length) { | ||||||
|                     const lastMessage = chat[chat.length - 1]; |                     const lastMessage = chat[chat.length - 1]; | ||||||
|                     const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); |                     const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); | ||||||
|                     if (shouldDeleteMessage) { |                     if (shouldDeleteMessage) { | ||||||
| @@ -4457,7 +4457,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro | |||||||
|  |  | ||||||
|         if (ToolManager.isFunctionCallingSupported()) { |         if (ToolManager.isFunctionCallingSupported()) { | ||||||
|             const invocations = await ToolManager.checkFunctionToolCalls(data); |             const invocations = await ToolManager.checkFunctionToolCalls(data); | ||||||
|             if (invocations.length) { |             if (Array.isArray(invocations) && invocations.length) { | ||||||
|                 ToolManager.saveFunctionToolInvocations(invocations); |                 ToolManager.saveFunctionToolInvocations(invocations); | ||||||
|                 return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); |                 return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); | ||||||
|             } |             } | ||||||
|   | |||||||
| @@ -454,7 +454,8 @@ function setOpenAIMessages(chat) { | |||||||
|         if (role == 'user' && oai_settings.wrap_in_quotes) content = `"${content}"`; |         if (role == 'user' && oai_settings.wrap_in_quotes) content = `"${content}"`; | ||||||
|         const name = chat[j]['name']; |         const name = chat[j]['name']; | ||||||
|         const image = chat[j]?.extra?.image; |         const image = chat[j]?.extra?.image; | ||||||
|         messages[i] = { 'role': role, 'content': content, name: name, 'image': image }; |         const invocations = chat[j]?.extra?.tool_invocations; | ||||||
|  |         messages[i] = { 'role': role, 'content': content, name: name, 'image': image, 'invocations': invocations }; | ||||||
|         j++; |         j++; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -702,6 +703,7 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     const imageInlining = isImageInliningSupported(); |     const imageInlining = isImageInliningSupported(); | ||||||
|  |     const toolCalling = ToolManager.isFunctionCallingSupported(); | ||||||
|  |  | ||||||
|     // Insert chat messages as long as there is budget available |     // Insert chat messages as long as there is budget available | ||||||
|     const chatPool = [...messages].reverse(); |     const chatPool = [...messages].reverse(); | ||||||
| @@ -723,6 +725,24 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul | |||||||
|             await chatMessage.addImage(chatPrompt.image); |             await chatMessage.addImage(chatPrompt.image); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         if (toolCalling && Array.isArray(chatPrompt.invocations)) { | ||||||
|  |             /** @type {import('./tool-calling.js').ToolInvocation[]} */ | ||||||
|  |             const invocations = chatPrompt.invocations.slice().reverse(); | ||||||
|  |             const toolCallMessage = new Message('assistant', undefined, 'toolCall-' + chatMessage.identifier); | ||||||
|  |             toolCallMessage.setToolCalls(invocations); | ||||||
|  |             if (chatCompletion.canAfford(toolCallMessage)) { | ||||||
|  |                 for (const invocation of invocations) { | ||||||
|  |                     const toolResultMessage = new Message('tool', invocation.result, invocation.id); | ||||||
|  |                     const canAfford = chatCompletion.canAfford(toolResultMessage) && chatCompletion.canAfford(toolCallMessage); | ||||||
|  |                     if (!canAfford) { | ||||||
|  |                         break; | ||||||
|  |                     } | ||||||
|  |                     chatCompletion.insertAtStart(toolResultMessage, 'chatHistory'); | ||||||
|  |                 } | ||||||
|  |                 chatCompletion.insertAtStart(toolCallMessage, 'chatHistory'); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|         if (chatCompletion.canAfford(chatMessage)) { |         if (chatCompletion.canAfford(chatMessage)) { | ||||||
|             if (type === 'continue' && oai_settings.continue_prefill && chatPrompt === firstNonInjected) { |             if (type === 'continue' && oai_settings.continue_prefill && chatPrompt === firstNonInjected) { | ||||||
|                 // in case we are using continue_prefill and the latest message is an assistant message, we want to prepend the users assistant prefill on the message |                 // in case we are using continue_prefill and the latest message is an assistant message, we want to prepend the users assistant prefill on the message | ||||||
| @@ -2193,6 +2213,8 @@ class Message { | |||||||
|     content; |     content; | ||||||
|     /** @type {string} */ |     /** @type {string} */ | ||||||
|     name; |     name; | ||||||
|  |     /** @type {object} */ | ||||||
|  |     tool_call = null; | ||||||
|  |  | ||||||
|     /** |     /** | ||||||
|      * @constructor |      * @constructor | ||||||
| @@ -2217,6 +2239,22 @@ class Message { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /** | ||||||
|  |      * Reconstruct the message from a tool invocation. | ||||||
|  |      * @param {import('./tool-calling.js').ToolInvocation[]} invocations | ||||||
|  |      */ | ||||||
|  |     setToolCalls(invocations) { | ||||||
|  |         this.tool_calls = invocations.map(i => ({ | ||||||
|  |             id: i.id, | ||||||
|  |             type: 'function', | ||||||
|  |             function: { | ||||||
|  |                 arguments: i.parameters, | ||||||
|  |                 name: i.name, | ||||||
|  |             }, | ||||||
|  |         })); | ||||||
|  |         this.tokens = tokenHandler.count({ role: this.role, tool_calls: JSON.stringify(this.tool_calls) }); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     setName(name) { |     setName(name) { | ||||||
|         this.name = name; |         this.name = name; | ||||||
|         this.tokens = tokenHandler.count({ role: this.role, content: this.content, name: this.name }); |         this.tokens = tokenHandler.count({ role: this.role, content: this.content, name: this.name }); | ||||||
| @@ -2564,7 +2602,7 @@ export class ChatCompletion { | |||||||
|         this.checkTokenBudget(message, message.identifier); |         this.checkTokenBudget(message, message.identifier); | ||||||
|  |  | ||||||
|         const index = this.findMessageIndex(identifier); |         const index = this.findMessageIndex(identifier); | ||||||
|         if (message.content) { |         if (message.content || message.tool_calls) { | ||||||
|             if ('start' === position) this.messages.collection[index].collection.unshift(message); |             if ('start' === position) this.messages.collection[index].collection.unshift(message); | ||||||
|             else if ('end' === position) this.messages.collection[index].collection.push(message); |             else if ('end' === position) this.messages.collection[index].collection.push(message); | ||||||
|             else if (typeof position === 'number') this.messages.collection[index].collection.splice(position, 0, message); |             else if (typeof position === 'number') this.messages.collection[index].collection.splice(position, 0, message); | ||||||
| @@ -2633,8 +2671,14 @@ export class ChatCompletion { | |||||||
|         for (let item of this.messages.collection) { |         for (let item of this.messages.collection) { | ||||||
|             if (item instanceof MessageCollection) { |             if (item instanceof MessageCollection) { | ||||||
|                 chat.push(...item.getChat()); |                 chat.push(...item.getChat()); | ||||||
|             } else if (item instanceof Message && item.content) { |             } else if (item instanceof Message && (item.content || item.tool_calls)) { | ||||||
|                 const message = { role: item.role, content: item.content, ...(item.name ? { name: item.name } : {}) }; |                 const message = { | ||||||
|  |                     role: item.role, | ||||||
|  |                     content: item.content, | ||||||
|  |                     ...(item.name ? { name: item.name } : {}), | ||||||
|  |                     ...(item.tool_calls ? { tool_calls: item.tool_calls } : {}), | ||||||
|  |                     ...(item.role === 'tool' ? { tool_call_id: item.identifier } : {}), | ||||||
|  |                 }; | ||||||
|                 chat.push(message); |                 chat.push(message); | ||||||
|             } else { |             } else { | ||||||
|                 this.log(`Skipping invalid or empty message in collection: ${JSON.stringify(item)}`); |                 this.log(`Skipping invalid or empty message in collection: ${JSON.stringify(item)}`); | ||||||
|   | |||||||
| @@ -307,7 +307,7 @@ export class ToolManager { | |||||||
|  |  | ||||||
|         if (oaiCompat.includes(oai_settings.chat_completion_source)) { |         if (oaiCompat.includes(oai_settings.chat_completion_source)) { | ||||||
|             if (!Array.isArray(toolCalls)) { |             if (!Array.isArray(toolCalls)) { | ||||||
|                 return; |                 return []; | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             for (const toolCall of toolCalls) { |             for (const toolCall of toolCalls) { | ||||||
| @@ -363,7 +363,7 @@ export class ToolManager { | |||||||
|  |  | ||||||
|     /** |     /** | ||||||
|      * Saves function tool invocations to the last user chat message extra metadata. |      * Saves function tool invocations to the last user chat message extra metadata. | ||||||
|      * @param {ToolInvocation[]} invocations |      * @param {ToolInvocation[]} invocations Successful tool invocations | ||||||
|      */ |      */ | ||||||
|     static saveFunctionToolInvocations(invocations) { |     static saveFunctionToolInvocations(invocations) { | ||||||
|         for (let index = chat.length - 1; index >= 0; index--) { |         for (let index = chat.length - 1; index >= 0; index--) { | ||||||
| @@ -373,7 +373,6 @@ export class ToolManager { | |||||||
|                     message.extra = {}; |                     message.extra = {}; | ||||||
|                 } |                 } | ||||||
|                 message.extra.tool_invocations = invocations; |                 message.extra.tool_invocations = invocations; | ||||||
|                 debugger; |  | ||||||
|                 break; |                 break; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user