Implement function tool calling for OpenAI
This commit is contained in:
parent
8006795897
commit
c94c06ed4d
|
@ -4408,7 +4408,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
|
|||
|
||||
if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) {
|
||||
const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls);
|
||||
if (invocations.length) {
|
||||
if (Array.isArray(invocations) && invocations.length) {
|
||||
const lastMessage = chat[chat.length - 1];
|
||||
const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result);
|
||||
if (shouldDeleteMessage) {
|
||||
|
@ -4457,7 +4457,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
|
|||
|
||||
if (ToolManager.isFunctionCallingSupported()) {
|
||||
const invocations = await ToolManager.checkFunctionToolCalls(data);
|
||||
if (invocations.length) {
|
||||
if (Array.isArray(invocations) && invocations.length) {
|
||||
ToolManager.saveFunctionToolInvocations(invocations);
|
||||
return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun);
|
||||
}
|
||||
|
|
|
@ -454,7 +454,8 @@ function setOpenAIMessages(chat) {
|
|||
if (role == 'user' && oai_settings.wrap_in_quotes) content = `"${content}"`;
|
||||
const name = chat[j]['name'];
|
||||
const image = chat[j]?.extra?.image;
|
||||
messages[i] = { 'role': role, 'content': content, name: name, 'image': image };
|
||||
const invocations = chat[j]?.extra?.tool_invocations;
|
||||
messages[i] = { 'role': role, 'content': content, name: name, 'image': image, 'invocations': invocations };
|
||||
j++;
|
||||
}
|
||||
|
||||
|
@ -702,6 +703,7 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
|
|||
}
|
||||
|
||||
const imageInlining = isImageInliningSupported();
|
||||
const toolCalling = ToolManager.isFunctionCallingSupported();
|
||||
|
||||
// Insert chat messages as long as there is budget available
|
||||
const chatPool = [...messages].reverse();
|
||||
|
@ -723,6 +725,24 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul
|
|||
await chatMessage.addImage(chatPrompt.image);
|
||||
}
|
||||
|
||||
if (toolCalling && Array.isArray(chatPrompt.invocations)) {
|
||||
/** @type {import('./tool-calling.js').ToolInvocation[]} */
|
||||
const invocations = chatPrompt.invocations.slice().reverse();
|
||||
const toolCallMessage = new Message('assistant', undefined, 'toolCall-' + chatMessage.identifier);
|
||||
toolCallMessage.setToolCalls(invocations);
|
||||
if (chatCompletion.canAfford(toolCallMessage)) {
|
||||
for (const invocation of invocations) {
|
||||
const toolResultMessage = new Message('tool', invocation.result, invocation.id);
|
||||
const canAfford = chatCompletion.canAfford(toolResultMessage) && chatCompletion.canAfford(toolCallMessage);
|
||||
if (!canAfford) {
|
||||
break;
|
||||
}
|
||||
chatCompletion.insertAtStart(toolResultMessage, 'chatHistory');
|
||||
}
|
||||
chatCompletion.insertAtStart(toolCallMessage, 'chatHistory');
|
||||
}
|
||||
}
|
||||
|
||||
if (chatCompletion.canAfford(chatMessage)) {
|
||||
if (type === 'continue' && oai_settings.continue_prefill && chatPrompt === firstNonInjected) {
|
||||
// in case we are using continue_prefill and the latest message is an assistant message, we want to prepend the users assistant prefill on the message
|
||||
|
@ -2193,6 +2213,8 @@ class Message {
|
|||
content;
|
||||
/** @type {string} */
|
||||
name;
|
||||
/** @type {object} */
|
||||
tool_call = null;
|
||||
|
||||
/**
|
||||
* @constructor
|
||||
|
@ -2217,6 +2239,22 @@ class Message {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconstruct the message from a tool invocation.
|
||||
* @param {import('./tool-calling.js').ToolInvocation[]} invocations
|
||||
*/
|
||||
setToolCalls(invocations) {
|
||||
this.tool_calls = invocations.map(i => ({
|
||||
id: i.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
arguments: i.parameters,
|
||||
name: i.name,
|
||||
},
|
||||
}));
|
||||
this.tokens = tokenHandler.count({ role: this.role, tool_calls: JSON.stringify(this.tool_calls) });
|
||||
}
|
||||
|
||||
setName(name) {
|
||||
this.name = name;
|
||||
this.tokens = tokenHandler.count({ role: this.role, content: this.content, name: this.name });
|
||||
|
@ -2564,7 +2602,7 @@ export class ChatCompletion {
|
|||
this.checkTokenBudget(message, message.identifier);
|
||||
|
||||
const index = this.findMessageIndex(identifier);
|
||||
if (message.content) {
|
||||
if (message.content || message.tool_calls) {
|
||||
if ('start' === position) this.messages.collection[index].collection.unshift(message);
|
||||
else if ('end' === position) this.messages.collection[index].collection.push(message);
|
||||
else if (typeof position === 'number') this.messages.collection[index].collection.splice(position, 0, message);
|
||||
|
@ -2633,8 +2671,14 @@ export class ChatCompletion {
|
|||
for (let item of this.messages.collection) {
|
||||
if (item instanceof MessageCollection) {
|
||||
chat.push(...item.getChat());
|
||||
} else if (item instanceof Message && item.content) {
|
||||
const message = { role: item.role, content: item.content, ...(item.name ? { name: item.name } : {}) };
|
||||
} else if (item instanceof Message && (item.content || item.tool_calls)) {
|
||||
const message = {
|
||||
role: item.role,
|
||||
content: item.content,
|
||||
...(item.name ? { name: item.name } : {}),
|
||||
...(item.tool_calls ? { tool_calls: item.tool_calls } : {}),
|
||||
...(item.role === 'tool' ? { tool_call_id: item.identifier } : {}),
|
||||
};
|
||||
chat.push(message);
|
||||
} else {
|
||||
this.log(`Skipping invalid or empty message in collection: ${JSON.stringify(item)}`);
|
||||
|
|
|
@ -307,7 +307,7 @@ export class ToolManager {
|
|||
|
||||
if (oaiCompat.includes(oai_settings.chat_completion_source)) {
|
||||
if (!Array.isArray(toolCalls)) {
|
||||
return;
|
||||
return [];
|
||||
}
|
||||
|
||||
for (const toolCall of toolCalls) {
|
||||
|
@ -363,7 +363,7 @@ export class ToolManager {
|
|||
|
||||
/**
|
||||
* Saves function tool invocations to the last user chat message extra metadata.
|
||||
* @param {ToolInvocation[]} invocations
|
||||
* @param {ToolInvocation[]} invocations Successful tool invocations
|
||||
*/
|
||||
static saveFunctionToolInvocations(invocations) {
|
||||
for (let index = chat.length - 1; index >= 0; index--) {
|
||||
|
@ -373,7 +373,6 @@ export class ToolManager {
|
|||
message.extra = {};
|
||||
}
|
||||
message.extra.tool_invocations = invocations;
|
||||
debugger;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue