mirror of
				https://github.com/SillyTavern/SillyTavern.git
				synced 2025-06-05 21:59:27 +02:00 
			
		
		
		
	Gemini: Add tool calling
This commit is contained in:
		| @@ -1962,7 +1962,7 @@ | ||||
|                                             </span> | ||||
|                                         </div> | ||||
|                                     </div> | ||||
|                                     <div class="range-block" data-source="openai,cohere,mistralai,custom,claude,openrouter,groq,deepseek"> | ||||
|                                     <div class="range-block" data-source="openai,cohere,mistralai,custom,claude,openrouter,groq,deepseek,makersuite"> | ||||
|                                         <label for="openai_function_calling" class="checkbox_label flexWrap widthFreeExpand"> | ||||
|                                             <input id="openai_function_calling" type="checkbox" /> | ||||
|                                             <span data-i18n="Enable function calling">Enable function calling</span> | ||||
|   | ||||
| @@ -137,9 +137,14 @@ async function* parseStreamData(json) { | ||||
|     else if (Array.isArray(json.candidates)) { | ||||
|         for (let i = 0; i < json.candidates.length; i++) { | ||||
|             const isNotPrimary = json.candidates?.[0]?.index > 0; | ||||
|             const hasToolCalls = json?.candidates?.[0]?.content?.parts?.some(p => p?.functionCall); | ||||
|             if (isNotPrimary || json.candidates.length === 0) { | ||||
|                 return null; | ||||
|             } | ||||
|             if (hasToolCalls) { | ||||
|                 yield { data: json, chunk: '' }; | ||||
|                 return; | ||||
|             } | ||||
|             if (typeof json.candidates[0].content === 'object' && Array.isArray(json.candidates[i].content.parts)) { | ||||
|                 for (let j = 0; j < json.candidates[i].content.parts.length; j++) { | ||||
|                     if (typeof json.candidates[i].content.parts[j].text === 'string') { | ||||
|   | ||||
| @@ -506,6 +506,26 @@ export class ToolManager { | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         if (Array.isArray(parsed?.candidates)) { | ||||
|             for (let choiceIndex = 0; choiceIndex < parsed.candidates.length; choiceIndex++) { | ||||
|                 const candidate = parsed.candidates[choiceIndex]; | ||||
|                 if (Array.isArray(candidate?.content?.parts)) { | ||||
|                     for (let toolCallIndex = 0; toolCallIndex < candidate.content.parts.length; toolCallIndex++) { | ||||
|                         const part = candidate.content.parts[toolCallIndex]; | ||||
|                         if (part.functionCall) { | ||||
|                             if (!Array.isArray(toolCalls[choiceIndex])) { | ||||
|                                 toolCalls[choiceIndex] = []; | ||||
|                             } | ||||
|                             if (toolCalls[choiceIndex][toolCallIndex] === undefined) { | ||||
|                                 toolCalls[choiceIndex][toolCallIndex] = {}; | ||||
|                             } | ||||
|                             const targetToolCall = toolCalls[choiceIndex][toolCallIndex]; | ||||
|                             ToolManager.#applyToolCallDelta(targetToolCall, part.functionCall); | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /** | ||||
| @@ -564,6 +584,7 @@ export class ToolManager { | ||||
|             chat_completion_sources.GROQ, | ||||
|             chat_completion_sources.COHERE, | ||||
|             chat_completion_sources.DEEPSEEK, | ||||
|             chat_completion_sources.MAKERSUITE, | ||||
|         ]; | ||||
|         return supportedSources.includes(oai_settings.chat_completion_source); | ||||
|     } | ||||
| @@ -585,8 +606,11 @@ export class ToolManager { | ||||
|      * @returns {any[]} Tool calls from the response data | ||||
|      */ | ||||
|     static #getToolCallsFromData(data) { | ||||
|         const getRandomId = () => Math.random().toString(36).substring(2); | ||||
|         const isClaudeToolCall = c => Array.isArray(c) ? c.filter(x => x).every(isClaudeToolCall) : c?.input && c?.name && c?.id; | ||||
|         const isGoogleToolCall = c => Array.isArray(c) ? c.filter(x => x).every(isGoogleToolCall) : c?.name && c?.args; | ||||
|         const convertClaudeToolCall = c => ({ id: c.id, function: { name: c.name, arguments: c.input } }); | ||||
|         const convertGoogleToolCall = (c) => ({ id: getRandomId(), function: { name: c.name, arguments: c.args } }); | ||||
|  | ||||
|         // Parsed tool calls from streaming data | ||||
|         if (Array.isArray(data) && data.length > 0 && Array.isArray(data[0])) { | ||||
| @@ -594,6 +618,10 @@ export class ToolManager { | ||||
|                 return data[0].filter(x => x).map(convertClaudeToolCall); | ||||
|             } | ||||
|  | ||||
|             if (isGoogleToolCall(data[0])) { | ||||
|                 return data[0].filter(x => x).map(convertGoogleToolCall); | ||||
|             } | ||||
|  | ||||
|             if (typeof data[0]?.[0]?.tool_calls === 'object') { | ||||
|                 return Array.isArray(data[0]?.[0]?.tool_calls) ? data[0][0].tool_calls : [data[0][0].tool_calls]; | ||||
|             } | ||||
| @@ -601,6 +629,11 @@ export class ToolManager { | ||||
|             return data[0]; | ||||
|         } | ||||
|  | ||||
|         // Google AI Studio tool calls | ||||
|         if (Array.isArray(data?.responseContent?.parts)) { | ||||
|             return data.responseContent.parts.filter(p => p.functionCall).map(p => convertGoogleToolCall(p.functionCall)); | ||||
|         } | ||||
|  | ||||
|         // Parsed tool calls from non-streaming data | ||||
|         if (Array.isArray(data?.choices)) { | ||||
|             // Find a choice with 0-index | ||||
|   | ||||
| @@ -385,6 +385,19 @@ async function sendMakerSuiteRequest(request, response) { | ||||
|             tools.push(searchTool); | ||||
|         } | ||||
|  | ||||
|         if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { | ||||
|             const functionDeclarations = []; | ||||
|             for (const tool of request.body.tools) { | ||||
|                 if (tool.type === 'function') { | ||||
|                     if (tool.function.parameters?.$schema) { | ||||
|                         delete tool.function.parameters.$schema; | ||||
|                     } | ||||
|                     functionDeclarations.push(tool.function); | ||||
|                 } | ||||
|             } | ||||
|             tools.push({ function_declarations: functionDeclarations }); | ||||
|         } | ||||
|  | ||||
|         let body = { | ||||
|             contents: prompt.contents, | ||||
|             safetySettings: safetySettings, | ||||
| @@ -454,10 +467,11 @@ async function sendMakerSuiteRequest(request, response) { | ||||
|             } | ||||
|  | ||||
|             const responseContent = candidates[0].content ?? candidates[0].output; | ||||
|             const functionCall = (candidates?.[0]?.content?.parts ?? []).some(part => part.functionCall); | ||||
|             console.warn('Google AI Studio response:', responseContent); | ||||
|  | ||||
|             const responseText = typeof responseContent === 'string' ? responseContent : responseContent?.parts?.filter(part => !part.thought)?.map(part => part.text)?.join('\n\n'); | ||||
|             if (!responseText) { | ||||
|             if (!responseText && !functionCall) { | ||||
|                 let message = 'Google AI Studio Candidate text empty'; | ||||
|                 console.warn(message, generateResponseJson); | ||||
|                 return response.send({ error: { message } }); | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| import crypto from 'node:crypto'; | ||||
| import { getConfigValue } from './util.js'; | ||||
| import { getConfigValue, tryParse } from './util.js'; | ||||
|  | ||||
| const PROMPT_PLACEHOLDER = getConfigValue('promptPlaceholder', 'Let\'s get started.'); | ||||
|  | ||||
| @@ -411,11 +411,12 @@ export function convertGooglePrompt(messages, model, useSysPrompt, names) { | ||||
|     } | ||||
|  | ||||
|     const system_instruction = { parts: { text: sys_prompt.trim() } }; | ||||
|     const toolNameMap = {}; | ||||
|  | ||||
|     const contents = []; | ||||
|     messages.forEach((message, index) => { | ||||
|         // fix the roles | ||||
|         if (message.role === 'system') { | ||||
|         if (message.role === 'system' || message.role === 'tool') { | ||||
|             message.role = 'user'; | ||||
|         } else if (message.role === 'assistant') { | ||||
|             message.role = 'model'; | ||||
| @@ -423,7 +424,21 @@ export function convertGooglePrompt(messages, model, useSysPrompt, names) { | ||||
|  | ||||
|         // Convert the content to an array of parts | ||||
|         if (!Array.isArray(message.content)) { | ||||
|             message.content = [{ type: 'text', text: String(message.content ?? '') }]; | ||||
|             const content = (() => { | ||||
|                 const hasToolCalls = Array.isArray(message.tool_calls) && message.tool_calls.length > 0; | ||||
|                 const hasToolCallId = typeof message.tool_call_id === 'string' && message.tool_call_id.length > 0; | ||||
|  | ||||
|                 if (hasToolCalls) { | ||||
|                     return { type: 'tool_calls', tool_calls: message.tool_calls }; | ||||
|                 } | ||||
|  | ||||
|                 if (hasToolCallId) { | ||||
|                     return { type: 'tool_call_id', tool_call_id: message.tool_call_id, content: String(message.content ?? '') }; | ||||
|                 } | ||||
|  | ||||
|                 return { type: 'text', text: String(message.content ?? '') }; | ||||
|             })(); | ||||
|             message.content = [content]; | ||||
|         } | ||||
|  | ||||
|         // similar story as claude | ||||
| @@ -455,6 +470,25 @@ export function convertGooglePrompt(messages, model, useSysPrompt, names) { | ||||
|         message.content.forEach((part) => { | ||||
|             if (part.type === 'text') { | ||||
|                 parts.push({ text: part.text }); | ||||
|             } else if (part.type === 'tool_call_id') { | ||||
|                 const name = toolNameMap[part.tool_call_id] ?? 'unknown'; | ||||
|                 parts.push({ | ||||
|                     functionResponse: { | ||||
|                         name: name, | ||||
|                         response: { name: name, content: part.content }, | ||||
|                     }, | ||||
|                 }); | ||||
|             } else if (part.type === 'tool_calls') { | ||||
|                 part.tool_calls.forEach((toolCall) => { | ||||
|                     parts.push({ | ||||
|                         functionCall: { | ||||
|                             name: toolCall.function.name, | ||||
|                             args: tryParse(toolCall.function.arguments) ?? toolCall.function.arguments, | ||||
|                         }, | ||||
|                     }); | ||||
|  | ||||
|                     toolNameMap[toolCall.id] = toolCall.function.name; | ||||
|                 }); | ||||
|             } else if (part.type === 'image_url' && isMultimodal) { | ||||
|                 const mimeType = part.image_url.url.split(';')[0].split(':')[1]; | ||||
|                 const base64Data = part.image_url.url.split(',')[1]; | ||||
| @@ -473,7 +507,7 @@ export function convertGooglePrompt(messages, model, useSysPrompt, names) { | ||||
|                 if (part.text) { | ||||
|                     contents[contents.length - 1].parts[0].text += '\n\n' + part.text; | ||||
|                 } | ||||
|                 if (part.inlineData) { | ||||
|                 if (part.inlineData || part.functionCall) { | ||||
|                     contents[contents.length - 1].parts.push(part); | ||||
|                 } | ||||
|             }); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user