New tool calling framework

This commit is contained in:
Cohee
2024-10-02 01:00:48 +03:00
parent b3e88c82b8
commit 8006795897
6 changed files with 427 additions and 253 deletions

View File

@ -246,6 +246,7 @@ import { initInputMarkdown } from './scripts/input-md-formatting.js';
import { AbortReason } from './scripts/util/AbortReason.js';
import { initSystemPrompts } from './scripts/sysprompt.js';
import { registerExtensionSlashCommands as initExtensionSlashCommands } from './scripts/extensions-slashcommands.js';
import { ToolManager } from './scripts/tool-calling.js';
//exporting functions and vars for mods
export {
@ -463,8 +464,6 @@ export const event_types = {
FILE_ATTACHMENT_DELETED: 'file_attachment_deleted',
WORLDINFO_FORCE_ACTIVATE: 'worldinfo_force_activate',
OPEN_CHARACTER_LIBRARY: 'open_character_library',
LLM_FUNCTION_TOOL_REGISTER: 'llm_function_tool_register',
LLM_FUNCTION_TOOL_CALL: 'llm_function_tool_call',
ONLINE_STATUS_CHANGED: 'online_status_changed',
IMAGE_SWIPED: 'image_swiped',
CONNECTION_PROFILE_LOADED: 'connection_profile_loaded',
@ -2921,6 +2920,7 @@ class StreamingProcessor {
this.swipes = [];
/** @type {import('./scripts/logprobs.js').TokenLogprobs[]} */
this.messageLogprobs = [];
this.toolCalls = [];
}
#checkDomElements(messageId) {
@ -3139,7 +3139,7 @@ class StreamingProcessor {
}
/**
* @returns {Generator<{ text: string, swipes: string[], logprobs: import('./scripts/logprobs.js').TokenLogprobs }, void, void>}
* @returns {Generator<{ text: string, swipes: string[], logprobs: import('./scripts/logprobs.js').TokenLogprobs, toolCalls: any[] }, void, void>}
*/
*nullStreamingGeneration() {
throw new Error('Generation function for streaming is not hooked up');
@ -3161,12 +3161,13 @@ class StreamingProcessor {
try {
const sw = new Stopwatch(1000 / power_user.streaming_fps);
const timestamps = [];
for await (const { text, swipes, logprobs } of this.generator()) {
for await (const { text, swipes, logprobs, toolCalls } of this.generator()) {
timestamps.push(Date.now());
if (this.isStopped) {
return;
}
this.toolCalls = toolCalls;
this.result = text;
this.swipes = Array.from(swipes ?? []);
if (logprobs) {
@ -4405,6 +4406,20 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
getMessage = continue_mag + getMessage;
}
if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) {
const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls);
if (invocations.length) {
const lastMessage = chat[chat.length - 1];
const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result);
if (shouldDeleteMessage) {
await deleteLastMessage();
streamingProcessor = null;
}
ToolManager.saveFunctionToolInvocations(invocations);
return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun);
}
}
if (streamingProcessor && !streamingProcessor.isStopped && streamingProcessor.isFinished) {
await streamingProcessor.onFinishStreaming(streamingProcessor.messageId, getMessage);
streamingProcessor = null;
@ -4440,6 +4455,14 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
throw new Error(data?.response);
}
if (ToolManager.isFunctionCallingSupported()) {
const invocations = await ToolManager.checkFunctionToolCalls(data);
if (invocations.length) {
ToolManager.saveFunctionToolInvocations(invocations);
return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun);
}
}
//const getData = await response.json();
let getMessage = extractMessageFromData(data);
let title = extractTitleFromData(data);
@ -7853,7 +7876,7 @@ function openAlternateGreetings() {
if (menu_type !== 'create') {
await createOrEditCharacter();
}
}
},
});
for (let index = 0; index < getArray().length; index++) {
@ -8130,6 +8153,8 @@ window['SillyTavern'].getContext = function () {
registerHelper: () => { },
registerMacro: MacrosParser.registerMacro.bind(MacrosParser),
unregisterMacro: MacrosParser.unregisterMacro.bind(MacrosParser),
registerFunctionTool: ToolManager.registerFunctionTool.bind(ToolManager),
unregisterFunctionTool: ToolManager.unregisterFunctionTool.bind(ToolManager),
registerDebugFunction: registerDebugFunction,
/** @deprecated Use renderExtensionTemplateAsync instead. */
renderExtensionTemplate: renderExtensionTemplate,