Limit tool calls recursion

This commit is contained in:
Cohee
2024-10-07 00:22:27 +03:00
parent 67ebf0fc06
commit 7db85e7ed8
2 changed files with 13 additions and 5 deletions

View File

@ -3409,9 +3409,9 @@ function removeLastMessage() {
* @param {GenerateOptions} options Generation options * @param {GenerateOptions} options Generation options
* @param {boolean} dryRun Whether to actually generate a message or just assemble the prompt * @param {boolean} dryRun Whether to actually generate a message or just assemble the prompt
* @returns {Promise<any>} Returns a promise that resolves when the text is done generating. * @returns {Promise<any>} Returns a promise that resolves when the text is done generating.
* @typedef {{automatic_trigger?: boolean, force_name2?: boolean, quiet_prompt?: string, quietToLoud?: boolean, skipWIAN?: boolean, force_chid?: number, signal?: AbortSignal, quietImage?: string, quietName?: string }} GenerateOptions * @typedef {{automatic_trigger?: boolean, force_name2?: boolean, quiet_prompt?: string, quietToLoud?: boolean, skipWIAN?: boolean, force_chid?: number, signal?: AbortSignal, quietImage?: string, quietName?: string, depth?: number }} GenerateOptions
*/ */
export async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName } = {}, dryRun = false) { export async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName, depth = 0 } = {}, dryRun = false) {
console.log('Generate entered'); console.log('Generate entered');
setGenerationProgress(0); setGenerationProgress(0);
generation_started = new Date(); generation_started = new Date();
@ -3631,7 +3631,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
// Collect messages with usable content // Collect messages with usable content
const canUseTools = ToolManager.isToolCallingSupported(); const canUseTools = ToolManager.isToolCallingSupported();
const canPerformToolCalls = !dryRun && ToolManager.canPerformToolCalls(type); const canPerformToolCalls = !dryRun && ToolManager.canPerformToolCalls(type) && depth < ToolManager.RECURSE_LIMIT;
let coreChat = chat.filter(x => !x.is_system || (canUseTools && Array.isArray(x.extra?.tool_invocations))); let coreChat = chat.filter(x => !x.is_system || (canUseTools && Array.isArray(x.extra?.tool_invocations)));
if (type === 'swipe') { if (type === 'swipe') {
coreChat.pop(); coreChat.pop();
@ -4485,8 +4485,9 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
} }
streamingProcessor = null; streamingProcessor = null;
depth = depth + 1;
await ToolManager.saveFunctionToolInvocations(invocationResult.invocations); await ToolManager.saveFunctionToolInvocations(invocationResult.invocations);
return Generate('normal', { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); return Generate('normal', { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName, depth }, dryRun);
} }
} }
@ -4577,8 +4578,9 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
return; return;
} }
depth = depth + 1;
await ToolManager.saveFunctionToolInvocations(invocationResult.invocations); await ToolManager.saveFunctionToolInvocations(invocationResult.invocations);
return Generate('normal', { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun); return Generate('normal', { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName, depth }, dryRun);
} }
} }

View File

@ -210,6 +210,12 @@ export class ToolManager {
static #INPUT_DELTA_KEY = '__input_json_delta'; static #INPUT_DELTA_KEY = '__input_json_delta';
/**
* The maximum number of times to recurse when parsing tool calls.
* @type {number}
*/
static RECURSE_LIMIT = 5;
/** /**
* Returns an Array of all tools that have been registered. * Returns an Array of all tools that have been registered.
* @type {ToolDefinition[]} * @type {ToolDefinition[]}