Implement function tool calling for OpenAI

This commit is contained in:
Cohee 2024-10-02 01:45:57 +03:00
parent 8006795897
commit c94c06ed4d
3 changed files with 52 additions and 9 deletions

View File

@ -4408,7 +4408,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) {
const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls);
if (invocations.length) {
if (Array.isArray(invocations) && invocations.length) {
const lastMessage = chat[chat.length - 1];
const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result);
if (shouldDeleteMessage) {
@ -4457,7 +4457,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
if (ToolManager.isFunctionCallingSupported()) {
const invocations = await ToolManager.checkFunctionToolCalls(data);
if (invocations.length) {
if (Array.isArray(invocations) && invocations.length) {
ToolManager.saveFunctionToolInvocations(invocations);
return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun);
}

View File

@ -454,7 +454,8 @@ function setOpenAIMessages(chat) {
if (role == 'user' && oai_settings.wrap_in_quotes) content = `"${content}"`;
const name = chat[j]['name'];
const image = chat[j]?.extra?.image;
messages[i] = { 'role': role, 'content': content, name: name, 'image': image };
const invocations = chat[j]?.extra?.tool_invocations;
messages[i] = { 'role': role, 'content': content, name: name, 'image': image, 'invocations': invocations };
j++;
}
@ -702,6 +703,7 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
}
const imageInlining = isImageInliningSupported();
const toolCalling = ToolManager.isFunctionCallingSupported();
// Insert chat messages as long as there is budget available
const chatPool = [...messages].reverse();
@ -723,6 +725,24 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
await chatMessage.addImage(chatPrompt.image);
}
if (toolCalling && Array.isArray(chatPrompt.invocations)) {
/** @type {import('./tool-calling.js').ToolInvocation[]} */
const invocations = chatPrompt.invocations.slice().reverse();
const toolCallMessage = new Message('assistant', undefined, 'toolCall-' + chatMessage.identifier);
toolCallMessage.setToolCalls(invocations);
if (chatCompletion.canAfford(toolCallMessage)) {
for (const invocation of invocations) {
const toolResultMessage = new Message('tool', invocation.result, invocation.id);
const canAfford = chatCompletion.canAfford(toolResultMessage) && chatCompletion.canAfford(toolCallMessage);
if (!canAfford) {
break;
}
chatCompletion.insertAtStart(toolResultMessage, 'chatHistory');
}
chatCompletion.insertAtStart(toolCallMessage, 'chatHistory');
}
}
if (chatCompletion.canAfford(chatMessage)) {
if (type === 'continue' && oai_settings.continue_prefill && chatPrompt === firstNonInjected) {
// in case we are using continue_prefill and the latest message is an assistant message, we want to prepend the users assistant prefill on the message
@ -2193,6 +2213,8 @@ class Message {
content;
/** @type {string} */
name;
/** @type {object} */
tool_call = null;
/**
* @constructor
@ -2217,6 +2239,22 @@ class Message {
}
}
/**
* Reconstruct the message from a tool invocation.
* @param {import('./tool-calling.js').ToolInvocation[]} invocations
*/
setToolCalls(invocations) {
this.tool_calls = invocations.map(i => ({
id: i.id,
type: 'function',
function: {
arguments: i.parameters,
name: i.name,
},
}));
this.tokens = tokenHandler.count({ role: this.role, tool_calls: JSON.stringify(this.tool_calls) });
}
setName(name) {
this.name = name;
this.tokens = tokenHandler.count({ role: this.role, content: this.content, name: this.name });
@ -2564,7 +2602,7 @@ export class ChatCompletion {
this.checkTokenBudget(message, message.identifier);
const index = this.findMessageIndex(identifier);
if (message.content) {
if (message.content || message.tool_calls) {
if ('start' === position) this.messages.collection[index].collection.unshift(message);
else if ('end' === position) this.messages.collection[index].collection.push(message);
else if (typeof position === 'number') this.messages.collection[index].collection.splice(position, 0, message);
@ -2633,8 +2671,14 @@ export class ChatCompletion {
for (let item of this.messages.collection) {
if (item instanceof MessageCollection) {
chat.push(...item.getChat());
} else if (item instanceof Message && item.content) {
const message = { role: item.role, content: item.content, ...(item.name ? { name: item.name } : {}) };
} else if (item instanceof Message && (item.content || item.tool_calls)) {
const message = {
role: item.role,
content: item.content,
...(item.name ? { name: item.name } : {}),
...(item.tool_calls ? { tool_calls: item.tool_calls } : {}),
...(item.role === 'tool' ? { tool_call_id: item.identifier } : {}),
};
chat.push(message);
} else {
this.log(`Skipping invalid or empty message in collection: ${JSON.stringify(item)}`);

View File

@ -307,7 +307,7 @@ export class ToolManager {
if (oaiCompat.includes(oai_settings.chat_completion_source)) {
if (!Array.isArray(toolCalls)) {
return;
return [];
}
for (const toolCall of toolCalls) {
@ -363,7 +363,7 @@ export class ToolManager {
/**
* Saves function tool invocations to the last user chat message extra metadata.
* @param {ToolInvocation[]} invocations
* @param {ToolInvocation[]} invocations Successful tool invocations
*/
static saveFunctionToolInvocations(invocations) {
for (let index = chat.length - 1; index >= 0; index--) {
@ -373,7 +373,6 @@ export class ToolManager {
message.extra = {};
}
message.extra.tool_invocations = invocations;
debugger;
break;
}
}