Extension framework for function tool calling

This commit is contained in:
Cohee
2024-05-25 15:31:57 +03:00
parent 439ef0dc5e
commit a20c6bb01e
5 changed files with 171 additions and 1 deletions

View File

@ -278,6 +278,7 @@ const default_settings = {
inline_image_quality: 'low',
bypass_status_check: false,
continue_prefill: false,
function_calling: false,
names_behavior: character_names_behavior.NONE,
continue_postfix: continue_postfix_types.SPACE,
custom_prompt_post_processing: custom_prompt_post_processing_types.NONE,
@ -355,6 +356,7 @@ const oai_settings = {
inline_image_quality: 'low',
bypass_status_check: false,
continue_prefill: false,
function_calling: false,
names_behavior: character_names_behavior.NONE,
continue_postfix: continue_postfix_types.SPACE,
custom_prompt_post_processing: custom_prompt_post_processing_types.NONE,
@ -1851,6 +1853,10 @@ async function sendOpenAIRequest(type, messages, signal) {
await eventSource.emit(event_types.CHAT_COMPLETION_SETTINGS_READY, generate_data);
if (isFunctionCallingSupported()) {
await registerFunctionTools(type, generate_data);
}
const generate_url = '/api/backends/chat-completions/generate';
const response = await fetch(generate_url, {
method: 'POST',
@ -1907,10 +1913,107 @@ async function sendOpenAIRequest(type, messages, signal) {
delay(1).then(() => saveLogprobsForActiveMessage(logprobs, null));
}
if (isFunctionCallingSupported()) {
await checkFunctionToolCalls(data);
}
return data;
}
}
/**
* Register function tools for the next chat completion request.
* @param {string} type Generation type
* @param {object} data Generation data
*/
async function registerFunctionTools(type, data) {
let toolChoice = 'auto';
const tools = [];
/**
* @type {registerFunctionTool}
*/
const registerFunctionTool = (name, description, parameters, required) => {
tools.push({
type: 'function',
function: {
name,
description,
parameters,
},
});
if (required) {
toolChoice = 'required';
}
};
/**
* @type {FunctionToolRegister}
*/
const args = {
type,
data,
registerFunctionTool,
};
await eventSource.emit(event_types.LLM_FUNCTION_TOOL_REGISTER, args);
if (tools.length) {
console.log('Registered function tools:', tools);
data['tools'] = tools;
data['tool_choice'] = toolChoice;
}
}
async function checkFunctionToolCalls(data) {
if (!Array.isArray(data.choices)) {
return;
}
// Find a choice with 0-index
const choice = data.choices.find(choice => choice.index === 0);
if (!choice) {
return;
}
const toolCalls = choice.message.tool_calls;
if (!Array.isArray(toolCalls)) {
return;
}
for (const toolCall of toolCalls) {
if (toolCall.type !== 'function') {
continue;
}
/** @type {FunctionToolCall} */
const args = toolCall.function;
console.log('Function tool call:', toolCall);
await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args);
data.allowEmptyResponse = true;
}
}
export function isFunctionCallingSupported() {
if (main_api !== 'openai') {
return false;
}
if (!oai_settings.function_calling) {
return false;
}
const supportedSources = [
chat_completion_sources.OPENAI,
chat_completion_sources.CUSTOM,
];
return supportedSources.includes(oai_settings.chat_completion_source);
}
function getStreamingReply(data) {
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
return data?.delta?.text || '';
@ -2781,6 +2884,7 @@ function loadOpenAISettings(data, settings) {
oai_settings.continue_prefill = settings.continue_prefill ?? default_settings.continue_prefill;
oai_settings.names_behavior = settings.names_behavior ?? default_settings.names_behavior;
oai_settings.continue_postfix = settings.continue_postfix ?? default_settings.continue_postfix;
oai_settings.function_calling = settings.function_calling ?? default_settings.function_calling;
// Migrate from old settings
if (settings.names_in_completion === true) {
@ -2849,6 +2953,7 @@ function loadOpenAISettings(data, settings) {
$('#openrouter_providers_chat').val(oai_settings.openrouter_providers).trigger('change');
$('#squash_system_messages').prop('checked', oai_settings.squash_system_messages);
$('#continue_prefill').prop('checked', oai_settings.continue_prefill);
$('#openai_function_calling').prop('checked', oai_settings.function_calling);
if (settings.impersonation_prompt !== undefined) oai_settings.impersonation_prompt = settings.impersonation_prompt;
$('#impersonation_prompt_textarea').val(oai_settings.impersonation_prompt);
@ -3132,6 +3237,7 @@ async function saveOpenAIPreset(name, settings, triggerUi = true) {
bypass_status_check: settings.bypass_status_check,
continue_prefill: settings.continue_prefill,
continue_postfix: settings.continue_postfix,
function_calling: settings.function_calling,
seed: settings.seed,
n: settings.n,
};
@ -3518,6 +3624,7 @@ function onSettingsPresetChange() {
inline_image_quality: ['#openai_inline_image_quality', 'inline_image_quality', false],
continue_prefill: ['#continue_prefill', 'continue_prefill', true],
continue_postfix: ['#continue_postfix', 'continue_postfix', false],
function_calling: ['#openai_function_calling', 'function_calling', true],
seed: ['#seed_openai', 'seed', false],
n: ['#n_openai', 'n', false],
};
@ -4785,6 +4892,11 @@ $(document).ready(async function () {
saveSettingsDebounced();
});
$('#openai_function_calling').on('input', function () {
oai_settings.function_calling = !!$(this).prop('checked');
saveSettingsDebounced();
});
$('#seed_openai').on('input', function () {
oai_settings.seed = Number($(this).val());
saveSettingsDebounced();