Implement function tool calling for OpenAI

This commit is contained in:
Cohee 2024-10-02 01:45:57 +03:00
parent 8006795897
commit c94c06ed4d
3 changed files with 52 additions and 9 deletions

View File

@ -4408,7 +4408,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) { if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) {
const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls); const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls);
if (invocations.length) { if (Array.isArray(invocations) && invocations.length) {
const lastMessage = chat[chat.length - 1]; const lastMessage = chat[chat.length - 1];
const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result); const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result);
if (shouldDeleteMessage) { if (shouldDeleteMessage) {
@ -4457,7 +4457,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
if (ToolManager.isFunctionCallingSupported()) { if (ToolManager.isFunctionCallingSupported()) {
const invocations = await ToolManager.checkFunctionToolCalls(data); const invocations = await ToolManager.checkFunctionToolCalls(data);
if (invocations.length) { if (Array.isArray(invocations) && invocations.length) {
ToolManager.saveFunctionToolInvocations(invocations); ToolManager.saveFunctionToolInvocations(invocations);
return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun);
} }

View File

@ -454,7 +454,8 @@ function setOpenAIMessages(chat) {
if (role == 'user' && oai_settings.wrap_in_quotes) content = `"${content}"`; if (role == 'user' && oai_settings.wrap_in_quotes) content = `"${content}"`;
const name = chat[j]['name']; const name = chat[j]['name'];
const image = chat[j]?.extra?.image; const image = chat[j]?.extra?.image;
messages[i] = { 'role': role, 'content': content, name: name, 'image': image }; const invocations = chat[j]?.extra?.tool_invocations;
messages[i] = { 'role': role, 'content': content, name: name, 'image': image, 'invocations': invocations };
j++; j++;
} }
@ -702,6 +703,7 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
} }
const imageInlining = isImageInliningSupported(); const imageInlining = isImageInliningSupported();
const toolCalling = ToolManager.isFunctionCallingSupported();
// Insert chat messages as long as there is budget available // Insert chat messages as long as there is budget available
const chatPool = [...messages].reverse(); const chatPool = [...messages].reverse();
@ -723,6 +725,24 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
await chatMessage.addImage(chatPrompt.image); await chatMessage.addImage(chatPrompt.image);
} }
if (toolCalling && Array.isArray(chatPrompt.invocations)) {
/** @type {import('./tool-calling.js').ToolInvocation[]} */
const invocations = chatPrompt.invocations.slice().reverse();
const toolCallMessage = new Message('assistant', undefined, 'toolCall-' + chatMessage.identifier);
toolCallMessage.setToolCalls(invocations);
if (chatCompletion.canAfford(toolCallMessage)) {
for (const invocation of invocations) {
const toolResultMessage = new Message('tool', invocation.result, invocation.id);
const canAfford = chatCompletion.canAfford(toolResultMessage) && chatCompletion.canAfford(toolCallMessage);
if (!canAfford) {
break;
}
chatCompletion.insertAtStart(toolResultMessage, 'chatHistory');
}
chatCompletion.insertAtStart(toolCallMessage, 'chatHistory');
}
}
if (chatCompletion.canAfford(chatMessage)) { if (chatCompletion.canAfford(chatMessage)) {
if (type === 'continue' && oai_settings.continue_prefill && chatPrompt === firstNonInjected) { if (type === 'continue' && oai_settings.continue_prefill && chatPrompt === firstNonInjected) {
// in case we are using continue_prefill and the latest message is an assistant message, we want to prepend the users assistant prefill on the message // in case we are using continue_prefill and the latest message is an assistant message, we want to prepend the users assistant prefill on the message
@ -2193,6 +2213,8 @@ class Message {
content; content;
/** @type {string} */ /** @type {string} */
name; name;
/** @type {object} */
tool_call = null;
/** /**
* @constructor * @constructor
@ -2217,6 +2239,22 @@ class Message {
} }
} }
/**
* Reconstruct the message from a tool invocation.
* @param {import('./tool-calling.js').ToolInvocation[]} invocations
*/
setToolCalls(invocations) {
this.tool_calls = invocations.map(i => ({
id: i.id,
type: 'function',
function: {
arguments: i.parameters,
name: i.name,
},
}));
this.tokens = tokenHandler.count({ role: this.role, tool_calls: JSON.stringify(this.tool_calls) });
}
setName(name) { setName(name) {
this.name = name; this.name = name;
this.tokens = tokenHandler.count({ role: this.role, content: this.content, name: this.name }); this.tokens = tokenHandler.count({ role: this.role, content: this.content, name: this.name });
@ -2564,7 +2602,7 @@ export class ChatCompletion {
this.checkTokenBudget(message, message.identifier); this.checkTokenBudget(message, message.identifier);
const index = this.findMessageIndex(identifier); const index = this.findMessageIndex(identifier);
if (message.content) { if (message.content || message.tool_calls) {
if ('start' === position) this.messages.collection[index].collection.unshift(message); if ('start' === position) this.messages.collection[index].collection.unshift(message);
else if ('end' === position) this.messages.collection[index].collection.push(message); else if ('end' === position) this.messages.collection[index].collection.push(message);
else if (typeof position === 'number') this.messages.collection[index].collection.splice(position, 0, message); else if (typeof position === 'number') this.messages.collection[index].collection.splice(position, 0, message);
@ -2633,8 +2671,14 @@ export class ChatCompletion {
for (let item of this.messages.collection) { for (let item of this.messages.collection) {
if (item instanceof MessageCollection) { if (item instanceof MessageCollection) {
chat.push(...item.getChat()); chat.push(...item.getChat());
} else if (item instanceof Message && item.content) { } else if (item instanceof Message && (item.content || item.tool_calls)) {
const message = { role: item.role, content: item.content, ...(item.name ? { name: item.name } : {}) }; const message = {
role: item.role,
content: item.content,
...(item.name ? { name: item.name } : {}),
...(item.tool_calls ? { tool_calls: item.tool_calls } : {}),
...(item.role === 'tool' ? { tool_call_id: item.identifier } : {}),
};
chat.push(message); chat.push(message);
} else { } else {
this.log(`Skipping invalid or empty message in collection: ${JSON.stringify(item)}`); this.log(`Skipping invalid or empty message in collection: ${JSON.stringify(item)}`);

View File

@ -307,7 +307,7 @@ export class ToolManager {
if (oaiCompat.includes(oai_settings.chat_completion_source)) { if (oaiCompat.includes(oai_settings.chat_completion_source)) {
if (!Array.isArray(toolCalls)) { if (!Array.isArray(toolCalls)) {
return; return [];
} }
for (const toolCall of toolCalls) { for (const toolCall of toolCalls) {
@ -363,7 +363,7 @@ export class ToolManager {
/** /**
* Saves function tool invocations to the last user chat message extra metadata. * Saves function tool invocations to the last user chat message extra metadata.
* @param {ToolInvocation[]} invocations * @param {ToolInvocation[]} invocations Successful tool invocations
*/ */
static saveFunctionToolInvocations(invocations) { static saveFunctionToolInvocations(invocations) {
for (let index = chat.length - 1; index >= 0; index--) { for (let index = chat.length - 1; index >= 0; index--) {
@ -373,7 +373,6 @@ export class ToolManager {
message.extra = {}; message.extra = {};
} }
message.extra.tool_invocations = invocations; message.extra.tool_invocations = invocations;
debugger;
break; break;
} }
} }