Gemini: Add tool calling

This commit is contained in:
Cohee
2025-03-06 23:52:35 +02:00
parent e9cf606c70
commit c9277cec28
5 changed files with 92 additions and 6 deletions

View File

@@ -1962,7 +1962,7 @@
</span>
</div>
</div>
<div class="range-block" data-source="openai,cohere,mistralai,custom,claude,openrouter,groq,deepseek">
<div class="range-block" data-source="openai,cohere,mistralai,custom,claude,openrouter,groq,deepseek,makersuite">
<label for="openai_function_calling" class="checkbox_label flexWrap widthFreeExpand">
<input id="openai_function_calling" type="checkbox" />
<span data-i18n="Enable function calling">Enable function calling</span>

View File

@@ -137,9 +137,14 @@ async function* parseStreamData(json) {
else if (Array.isArray(json.candidates)) {
for (let i = 0; i < json.candidates.length; i++) {
const isNotPrimary = json.candidates?.[0]?.index > 0;
const hasToolCalls = json?.candidates?.[0]?.content?.parts?.some(p => p?.functionCall);
if (isNotPrimary || json.candidates.length === 0) {
return null;
}
if (hasToolCalls) {
yield { data: json, chunk: '' };
return;
}
if (typeof json.candidates[0].content === 'object' && Array.isArray(json.candidates[i].content.parts)) {
for (let j = 0; j < json.candidates[i].content.parts.length; j++) {
if (typeof json.candidates[i].content.parts[j].text === 'string') {

View File

@@ -506,6 +506,26 @@ export class ToolManager {
}
}
}
if (Array.isArray(parsed?.candidates)) {
for (let choiceIndex = 0; choiceIndex < parsed.candidates.length; choiceIndex++) {
const candidate = parsed.candidates[choiceIndex];
if (Array.isArray(candidate?.content?.parts)) {
for (let toolCallIndex = 0; toolCallIndex < candidate.content.parts.length; toolCallIndex++) {
const part = candidate.content.parts[toolCallIndex];
if (part.functionCall) {
if (!Array.isArray(toolCalls[choiceIndex])) {
toolCalls[choiceIndex] = [];
}
if (toolCalls[choiceIndex][toolCallIndex] === undefined) {
toolCalls[choiceIndex][toolCallIndex] = {};
}
const targetToolCall = toolCalls[choiceIndex][toolCallIndex];
ToolManager.#applyToolCallDelta(targetToolCall, part.functionCall);
}
}
}
}
}
}
/**
@@ -564,6 +584,7 @@ export class ToolManager {
chat_completion_sources.GROQ,
chat_completion_sources.COHERE,
chat_completion_sources.DEEPSEEK,
chat_completion_sources.MAKERSUITE,
];
return supportedSources.includes(oai_settings.chat_completion_source);
}
@@ -585,8 +606,11 @@ export class ToolManager {
* @returns {any[]} Tool calls from the response data
*/
static #getToolCallsFromData(data) {
const getRandomId = () => Math.random().toString(36).substring(2);
const isClaudeToolCall = c => Array.isArray(c) ? c.filter(x => x).every(isClaudeToolCall) : c?.input && c?.name && c?.id;
const isGoogleToolCall = c => Array.isArray(c) ? c.filter(x => x).every(isGoogleToolCall) : c?.name && c?.args;
const convertClaudeToolCall = c => ({ id: c.id, function: { name: c.name, arguments: c.input } });
const convertGoogleToolCall = (c) => ({ id: getRandomId(), function: { name: c.name, arguments: c.args } });
// Parsed tool calls from streaming data
if (Array.isArray(data) && data.length > 0 && Array.isArray(data[0])) {
@@ -594,6 +618,10 @@ export class ToolManager {
return data[0].filter(x => x).map(convertClaudeToolCall);
}
if (isGoogleToolCall(data[0])) {
return data[0].filter(x => x).map(convertGoogleToolCall);
}
if (typeof data[0]?.[0]?.tool_calls === 'object') {
return Array.isArray(data[0]?.[0]?.tool_calls) ? data[0][0].tool_calls : [data[0][0].tool_calls];
}
@@ -601,6 +629,11 @@ export class ToolManager {
return data[0];
}
// Google AI Studio tool calls
if (Array.isArray(data?.responseContent?.parts)) {
return data.responseContent.parts.filter(p => p.functionCall).map(p => convertGoogleToolCall(p.functionCall));
}
// Parsed tool calls from non-streaming data
if (Array.isArray(data?.choices)) {
// Find a choice with 0-index

View File

@@ -385,6 +385,19 @@ async function sendMakerSuiteRequest(request, response) {
tools.push(searchTool);
}
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
const functionDeclarations = [];
for (const tool of request.body.tools) {
if (tool.type === 'function') {
if (tool.function.parameters?.$schema) {
delete tool.function.parameters.$schema;
}
functionDeclarations.push(tool.function);
}
}
tools.push({ function_declarations: functionDeclarations });
}
let body = {
contents: prompt.contents,
safetySettings: safetySettings,
@@ -454,10 +467,11 @@ async function sendMakerSuiteRequest(request, response) {
}
const responseContent = candidates[0].content ?? candidates[0].output;
const functionCall = (candidates?.[0]?.content?.parts ?? []).some(part => part.functionCall);
console.warn('Google AI Studio response:', responseContent);
const responseText = typeof responseContent === 'string' ? responseContent : responseContent?.parts?.filter(part => !part.thought)?.map(part => part.text)?.join('\n\n');
if (!responseText) {
if (!responseText && !functionCall) {
let message = 'Google AI Studio Candidate text empty';
console.warn(message, generateResponseJson);
return response.send({ error: { message } });

View File

@@ -1,5 +1,5 @@
import crypto from 'node:crypto';
import { getConfigValue } from './util.js';
import { getConfigValue, tryParse } from './util.js';
const PROMPT_PLACEHOLDER = getConfigValue('promptPlaceholder', 'Let\'s get started.');
@@ -411,11 +411,12 @@ export function convertGooglePrompt(messages, model, useSysPrompt, names) {
}
const system_instruction = { parts: { text: sys_prompt.trim() } };
const toolNameMap = {};
const contents = [];
messages.forEach((message, index) => {
// fix the roles
if (message.role === 'system') {
if (message.role === 'system' || message.role === 'tool') {
message.role = 'user';
} else if (message.role === 'assistant') {
message.role = 'model';
@@ -423,7 +424,21 @@ export function convertGooglePrompt(messages, model, useSysPrompt, names) {
// Convert the content to an array of parts
if (!Array.isArray(message.content)) {
message.content = [{ type: 'text', text: String(message.content ?? '') }];
const content = (() => {
const hasToolCalls = Array.isArray(message.tool_calls) && message.tool_calls.length > 0;
const hasToolCallId = typeof message.tool_call_id === 'string' && message.tool_call_id.length > 0;
if (hasToolCalls) {
return { type: 'tool_calls', tool_calls: message.tool_calls };
}
if (hasToolCallId) {
return { type: 'tool_call_id', tool_call_id: message.tool_call_id, content: String(message.content ?? '') };
}
return { type: 'text', text: String(message.content ?? '') };
})();
message.content = [content];
}
// similar story as claude
@@ -455,6 +470,25 @@ export function convertGooglePrompt(messages, model, useSysPrompt, names) {
message.content.forEach((part) => {
if (part.type === 'text') {
parts.push({ text: part.text });
} else if (part.type === 'tool_call_id') {
const name = toolNameMap[part.tool_call_id] ?? 'unknown';
parts.push({
functionResponse: {
name: name,
response: { name: name, content: part.content },
},
});
} else if (part.type === 'tool_calls') {
part.tool_calls.forEach((toolCall) => {
parts.push({
functionCall: {
name: toolCall.function.name,
args: tryParse(toolCall.function.arguments) ?? toolCall.function.arguments,
},
});
toolNameMap[toolCall.id] = toolCall.function.name;
});
} else if (part.type === 'image_url' && isMultimodal) {
const mimeType = part.image_url.url.split(';')[0].split(':')[1];
const base64Data = part.image_url.url.split(',')[1];
@@ -473,7 +507,7 @@ export function convertGooglePrompt(messages, model, useSysPrompt, names) {
if (part.text) {
contents[contents.length - 1].parts[0].text += '\n\n' + part.text;
}
if (part.inlineData) {
if (part.inlineData || part.functionCall) {
contents[contents.length - 1].parts.push(part);
}
});