New tool calling framework
This commit is contained in:
parent
b3e88c82b8
commit
8006795897
|
@ -1365,44 +1365,3 @@ declare namespace moment {
|
||||||
declare global {
|
declare global {
|
||||||
const moment: typeof moment;
|
const moment: typeof moment;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Callback data for the `LLM_FUNCTION_TOOL_REGISTER` event type that is triggered when a function tool can be registered.
|
|
||||||
*/
|
|
||||||
interface FunctionToolRegister {
|
|
||||||
/**
|
|
||||||
* The type of generation that is being used
|
|
||||||
*/
|
|
||||||
type?: string;
|
|
||||||
/**
|
|
||||||
* Generation data, including messages and sampling parameters
|
|
||||||
*/
|
|
||||||
data: Record<string, object>;
|
|
||||||
/**
|
|
||||||
* Callback to register an LLM function tool.
|
|
||||||
*/
|
|
||||||
registerFunctionTool: typeof registerFunctionTool;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Callback data for the `LLM_FUNCTION_TOOL_REGISTER` event type that is triggered when a function tool is registered.
|
|
||||||
* @param name Name of the function tool to register
|
|
||||||
* @param description Description of the function tool
|
|
||||||
* @param params JSON schema for the parameters of the function tool
|
|
||||||
* @param required Whether the function tool should be forced to be used
|
|
||||||
*/
|
|
||||||
declare function registerFunctionTool(name: string, description: string, params: object, required: boolean): Promise<void>;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Callback data for the `LLM_FUNCTION_TOOL_CALL` event type that is triggered when a function tool is called.
|
|
||||||
*/
|
|
||||||
interface FunctionToolCall {
|
|
||||||
/**
|
|
||||||
* Name of the function tool to call
|
|
||||||
*/
|
|
||||||
name: string;
|
|
||||||
/**
|
|
||||||
* JSON object with the parameters to pass to the function tool
|
|
||||||
*/
|
|
||||||
arguments: string;
|
|
||||||
}
|
|
||||||
|
|
|
@ -246,6 +246,7 @@ import { initInputMarkdown } from './scripts/input-md-formatting.js';
|
||||||
import { AbortReason } from './scripts/util/AbortReason.js';
|
import { AbortReason } from './scripts/util/AbortReason.js';
|
||||||
import { initSystemPrompts } from './scripts/sysprompt.js';
|
import { initSystemPrompts } from './scripts/sysprompt.js';
|
||||||
import { registerExtensionSlashCommands as initExtensionSlashCommands } from './scripts/extensions-slashcommands.js';
|
import { registerExtensionSlashCommands as initExtensionSlashCommands } from './scripts/extensions-slashcommands.js';
|
||||||
|
import { ToolManager } from './scripts/tool-calling.js';
|
||||||
|
|
||||||
//exporting functions and vars for mods
|
//exporting functions and vars for mods
|
||||||
export {
|
export {
|
||||||
|
@ -463,8 +464,6 @@ export const event_types = {
|
||||||
FILE_ATTACHMENT_DELETED: 'file_attachment_deleted',
|
FILE_ATTACHMENT_DELETED: 'file_attachment_deleted',
|
||||||
WORLDINFO_FORCE_ACTIVATE: 'worldinfo_force_activate',
|
WORLDINFO_FORCE_ACTIVATE: 'worldinfo_force_activate',
|
||||||
OPEN_CHARACTER_LIBRARY: 'open_character_library',
|
OPEN_CHARACTER_LIBRARY: 'open_character_library',
|
||||||
LLM_FUNCTION_TOOL_REGISTER: 'llm_function_tool_register',
|
|
||||||
LLM_FUNCTION_TOOL_CALL: 'llm_function_tool_call',
|
|
||||||
ONLINE_STATUS_CHANGED: 'online_status_changed',
|
ONLINE_STATUS_CHANGED: 'online_status_changed',
|
||||||
IMAGE_SWIPED: 'image_swiped',
|
IMAGE_SWIPED: 'image_swiped',
|
||||||
CONNECTION_PROFILE_LOADED: 'connection_profile_loaded',
|
CONNECTION_PROFILE_LOADED: 'connection_profile_loaded',
|
||||||
|
@ -2921,6 +2920,7 @@ class StreamingProcessor {
|
||||||
this.swipes = [];
|
this.swipes = [];
|
||||||
/** @type {import('./scripts/logprobs.js').TokenLogprobs[]} */
|
/** @type {import('./scripts/logprobs.js').TokenLogprobs[]} */
|
||||||
this.messageLogprobs = [];
|
this.messageLogprobs = [];
|
||||||
|
this.toolCalls = [];
|
||||||
}
|
}
|
||||||
|
|
||||||
#checkDomElements(messageId) {
|
#checkDomElements(messageId) {
|
||||||
|
@ -3139,7 +3139,7 @@ class StreamingProcessor {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @returns {Generator<{ text: string, swipes: string[], logprobs: import('./scripts/logprobs.js').TokenLogprobs }, void, void>}
|
* @returns {Generator<{ text: string, swipes: string[], logprobs: import('./scripts/logprobs.js').TokenLogprobs, toolCalls: any[] }, void, void>}
|
||||||
*/
|
*/
|
||||||
*nullStreamingGeneration() {
|
*nullStreamingGeneration() {
|
||||||
throw new Error('Generation function for streaming is not hooked up');
|
throw new Error('Generation function for streaming is not hooked up');
|
||||||
|
@ -3161,12 +3161,13 @@ class StreamingProcessor {
|
||||||
try {
|
try {
|
||||||
const sw = new Stopwatch(1000 / power_user.streaming_fps);
|
const sw = new Stopwatch(1000 / power_user.streaming_fps);
|
||||||
const timestamps = [];
|
const timestamps = [];
|
||||||
for await (const { text, swipes, logprobs } of this.generator()) {
|
for await (const { text, swipes, logprobs, toolCalls } of this.generator()) {
|
||||||
timestamps.push(Date.now());
|
timestamps.push(Date.now());
|
||||||
if (this.isStopped) {
|
if (this.isStopped) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.toolCalls = toolCalls;
|
||||||
this.result = text;
|
this.result = text;
|
||||||
this.swipes = Array.from(swipes ?? []);
|
this.swipes = Array.from(swipes ?? []);
|
||||||
if (logprobs) {
|
if (logprobs) {
|
||||||
|
@ -4405,6 +4406,20 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
|
||||||
getMessage = continue_mag + getMessage;
|
getMessage = continue_mag + getMessage;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ToolManager.isFunctionCallingSupported() && Array.isArray(streamingProcessor.toolCalls) && streamingProcessor.toolCalls.length) {
|
||||||
|
const invocations = await ToolManager.checkFunctionToolCalls(streamingProcessor.toolCalls);
|
||||||
|
if (invocations.length) {
|
||||||
|
const lastMessage = chat[chat.length - 1];
|
||||||
|
const shouldDeleteMessage = ['', '...'].includes(lastMessage?.mes) && ['', '...'].includes(streamingProcessor.result);
|
||||||
|
if (shouldDeleteMessage) {
|
||||||
|
await deleteLastMessage();
|
||||||
|
streamingProcessor = null;
|
||||||
|
}
|
||||||
|
ToolManager.saveFunctionToolInvocations(invocations);
|
||||||
|
return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (streamingProcessor && !streamingProcessor.isStopped && streamingProcessor.isFinished) {
|
if (streamingProcessor && !streamingProcessor.isStopped && streamingProcessor.isFinished) {
|
||||||
await streamingProcessor.onFinishStreaming(streamingProcessor.messageId, getMessage);
|
await streamingProcessor.onFinishStreaming(streamingProcessor.messageId, getMessage);
|
||||||
streamingProcessor = null;
|
streamingProcessor = null;
|
||||||
|
@ -4440,6 +4455,14 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
|
||||||
throw new Error(data?.response);
|
throw new Error(data?.response);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ToolManager.isFunctionCallingSupported()) {
|
||||||
|
const invocations = await ToolManager.checkFunctionToolCalls(data);
|
||||||
|
if (invocations.length) {
|
||||||
|
ToolManager.saveFunctionToolInvocations(invocations);
|
||||||
|
return Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName }, dryRun);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//const getData = await response.json();
|
//const getData = await response.json();
|
||||||
let getMessage = extractMessageFromData(data);
|
let getMessage = extractMessageFromData(data);
|
||||||
let title = extractTitleFromData(data);
|
let title = extractTitleFromData(data);
|
||||||
|
@ -7853,7 +7876,7 @@ function openAlternateGreetings() {
|
||||||
if (menu_type !== 'create') {
|
if (menu_type !== 'create') {
|
||||||
await createOrEditCharacter();
|
await createOrEditCharacter();
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
for (let index = 0; index < getArray().length; index++) {
|
for (let index = 0; index < getArray().length; index++) {
|
||||||
|
@ -8130,6 +8153,8 @@ window['SillyTavern'].getContext = function () {
|
||||||
registerHelper: () => { },
|
registerHelper: () => { },
|
||||||
registerMacro: MacrosParser.registerMacro.bind(MacrosParser),
|
registerMacro: MacrosParser.registerMacro.bind(MacrosParser),
|
||||||
unregisterMacro: MacrosParser.unregisterMacro.bind(MacrosParser),
|
unregisterMacro: MacrosParser.unregisterMacro.bind(MacrosParser),
|
||||||
|
registerFunctionTool: ToolManager.registerFunctionTool.bind(ToolManager),
|
||||||
|
unregisterFunctionTool: ToolManager.unregisterFunctionTool.bind(ToolManager),
|
||||||
registerDebugFunction: registerDebugFunction,
|
registerDebugFunction: registerDebugFunction,
|
||||||
/** @deprecated Use renderExtensionTemplateAsync instead. */
|
/** @deprecated Use renderExtensionTemplateAsync instead. */
|
||||||
renderExtensionTemplate: renderExtensionTemplate,
|
renderExtensionTemplate: renderExtensionTemplate,
|
||||||
|
|
|
@ -9,7 +9,6 @@ import { debounce_timeout } from '../../constants.js';
|
||||||
import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js';
|
import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js';
|
||||||
import { SlashCommand } from '../../slash-commands/SlashCommand.js';
|
import { SlashCommand } from '../../slash-commands/SlashCommand.js';
|
||||||
import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js';
|
import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js';
|
||||||
import { isFunctionCallingSupported } from '../../openai.js';
|
|
||||||
import { SlashCommandEnumValue, enumTypes } from '../../slash-commands/SlashCommandEnumValue.js';
|
import { SlashCommandEnumValue, enumTypes } from '../../slash-commands/SlashCommandEnumValue.js';
|
||||||
import { commonEnumProviders } from '../../slash-commands/SlashCommandCommonEnumsProvider.js';
|
import { commonEnumProviders } from '../../slash-commands/SlashCommandCommonEnumsProvider.js';
|
||||||
import { slashCommandReturnHelper } from '../../slash-commands/SlashCommandReturnHelper.js';
|
import { slashCommandReturnHelper } from '../../slash-commands/SlashCommandReturnHelper.js';
|
||||||
|
@ -21,7 +20,6 @@ const UPDATE_INTERVAL = 2000;
|
||||||
const STREAMING_UPDATE_INTERVAL = 10000;
|
const STREAMING_UPDATE_INTERVAL = 10000;
|
||||||
const TALKINGCHECK_UPDATE_INTERVAL = 500;
|
const TALKINGCHECK_UPDATE_INTERVAL = 500;
|
||||||
const DEFAULT_FALLBACK_EXPRESSION = 'joy';
|
const DEFAULT_FALLBACK_EXPRESSION = 'joy';
|
||||||
const FUNCTION_NAME = 'set_emotion';
|
|
||||||
const DEFAULT_LLM_PROMPT = 'Ignore previous instructions. Classify the emotion of the last message. Output just one word, e.g. "joy" or "anger". Choose only one of the following labels: {{labels}}';
|
const DEFAULT_LLM_PROMPT = 'Ignore previous instructions. Classify the emotion of the last message. Output just one word, e.g. "joy" or "anger". Choose only one of the following labels: {{labels}}';
|
||||||
const DEFAULT_EXPRESSIONS = [
|
const DEFAULT_EXPRESSIONS = [
|
||||||
'talkinghead',
|
'talkinghead',
|
||||||
|
@ -1017,10 +1015,6 @@ async function getLlmPrompt(labels) {
|
||||||
return '';
|
return '';
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isFunctionCallingSupported()) {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
|
|
||||||
const labelsString = labels.map(x => `"${x}"`).join(', ');
|
const labelsString = labels.map(x => `"${x}"`).join(', ');
|
||||||
const prompt = substituteParamsExtended(String(extension_settings.expressions.llmPrompt), { labels: labelsString });
|
const prompt = substituteParamsExtended(String(extension_settings.expressions.llmPrompt), { labels: labelsString });
|
||||||
return prompt;
|
return prompt;
|
||||||
|
@ -1056,41 +1050,6 @@ function parseLlmResponse(emotionResponse, labels) {
|
||||||
throw new Error('Could not parse emotion response ' + emotionResponse);
|
throw new Error('Could not parse emotion response ' + emotionResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Registers the function tool for the LLM API.
|
|
||||||
* @param {FunctionToolRegister} args Function tool register arguments.
|
|
||||||
*/
|
|
||||||
function onFunctionToolRegister(args) {
|
|
||||||
if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isFunctionCallingSupported()) {
|
|
||||||
// Only trigger on quiet mode
|
|
||||||
if (args.type !== 'quiet') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const emotions = DEFAULT_EXPRESSIONS.filter((e) => e != 'talkinghead');
|
|
||||||
const jsonSchema = {
|
|
||||||
$schema: 'http://json-schema.org/draft-04/schema#',
|
|
||||||
type: 'object',
|
|
||||||
properties: {
|
|
||||||
emotion: {
|
|
||||||
type: 'string',
|
|
||||||
enum: emotions,
|
|
||||||
description: `One of the following: ${JSON.stringify(emotions)}`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: [
|
|
||||||
'emotion',
|
|
||||||
],
|
|
||||||
};
|
|
||||||
args.registerFunctionTool(
|
|
||||||
FUNCTION_NAME,
|
|
||||||
substituteParams('Sets the label that best describes the current emotional state of {{char}}. Only select one of the enumerated values.'),
|
|
||||||
jsonSchema,
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function onTextGenSettingsReady(args) {
|
function onTextGenSettingsReady(args) {
|
||||||
// Only call if inside an API call
|
// Only call if inside an API call
|
||||||
if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isJsonSchemaSupported()) {
|
if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isJsonSchemaSupported()) {
|
||||||
|
@ -1164,18 +1123,9 @@ export async function getExpressionLabel(text, expressionsApi = extension_settin
|
||||||
|
|
||||||
const expressionsList = await getExpressionsList();
|
const expressionsList = await getExpressionsList();
|
||||||
const prompt = substituteParamsExtended(customPrompt, { labels: expressionsList }) || await getLlmPrompt(expressionsList);
|
const prompt = substituteParamsExtended(customPrompt, { labels: expressionsList }) || await getLlmPrompt(expressionsList);
|
||||||
let functionResult = null;
|
|
||||||
eventSource.once(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady);
|
eventSource.once(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady);
|
||||||
eventSource.once(event_types.LLM_FUNCTION_TOOL_REGISTER, onFunctionToolRegister);
|
|
||||||
eventSource.once(event_types.LLM_FUNCTION_TOOL_CALL, (/** @type {FunctionToolCall} */ args) => {
|
|
||||||
if (args.name !== FUNCTION_NAME) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
functionResult = args?.arguments;
|
|
||||||
});
|
|
||||||
const emotionResponse = await generateRaw(text, main_api, false, false, prompt);
|
const emotionResponse = await generateRaw(text, main_api, false, false, prompt);
|
||||||
return parseLlmResponse(functionResult || emotionResponse, expressionsList);
|
return parseLlmResponse(emotionResponse, expressionsList);
|
||||||
}
|
}
|
||||||
// Extras
|
// Extras
|
||||||
default: {
|
default: {
|
||||||
|
|
|
@ -70,6 +70,7 @@ import { renderTemplateAsync } from './templates.js';
|
||||||
import { SlashCommandEnumValue } from './slash-commands/SlashCommandEnumValue.js';
|
import { SlashCommandEnumValue } from './slash-commands/SlashCommandEnumValue.js';
|
||||||
import { Popup, POPUP_RESULT } from './popup.js';
|
import { Popup, POPUP_RESULT } from './popup.js';
|
||||||
import { t } from './i18n.js';
|
import { t } from './i18n.js';
|
||||||
|
import { ToolManager } from './tool-calling.js';
|
||||||
|
|
||||||
export {
|
export {
|
||||||
openai_messages_count,
|
openai_messages_count,
|
||||||
|
@ -1863,8 +1864,8 @@ async function sendOpenAIRequest(type, messages, signal) {
|
||||||
generate_data['seed'] = oai_settings.seed;
|
generate_data['seed'] = oai_settings.seed;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isFunctionCallingSupported() && !stream) {
|
if (!canMultiSwipe && ToolManager.isFunctionCallingSupported()) {
|
||||||
await registerFunctionTools(type, generate_data);
|
await ToolManager.registerFunctionToolsOpenAI(generate_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isOAI && oai_settings.openai_model.startsWith('o1-')) {
|
if (isOAI && oai_settings.openai_model.startsWith('o1-')) {
|
||||||
|
@ -1911,6 +1912,7 @@ async function sendOpenAIRequest(type, messages, signal) {
|
||||||
return async function* streamData() {
|
return async function* streamData() {
|
||||||
let text = '';
|
let text = '';
|
||||||
const swipes = [];
|
const swipes = [];
|
||||||
|
const toolCalls = [];
|
||||||
while (true) {
|
while (true) {
|
||||||
const { done, value } = await reader.read();
|
const { done, value } = await reader.read();
|
||||||
if (done) return;
|
if (done) return;
|
||||||
|
@ -1926,7 +1928,9 @@ async function sendOpenAIRequest(type, messages, signal) {
|
||||||
text += getStreamingReply(parsed);
|
text += getStreamingReply(parsed);
|
||||||
}
|
}
|
||||||
|
|
||||||
yield { text, swipes: swipes, logprobs: parseChatCompletionLogprobs(parsed) };
|
ToolManager.parseToolCalls(toolCalls, parsed);
|
||||||
|
|
||||||
|
yield { text, swipes: swipes, logprobs: parseChatCompletionLogprobs(parsed), toolCalls: toolCalls };
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -1948,147 +1952,10 @@ async function sendOpenAIRequest(type, messages, signal) {
|
||||||
delay(1).then(() => saveLogprobsForActiveMessage(logprobs, null));
|
delay(1).then(() => saveLogprobsForActiveMessage(logprobs, null));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isFunctionCallingSupported()) {
|
|
||||||
await checkFunctionToolCalls(data);
|
|
||||||
}
|
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Register function tools for the next chat completion request.
|
|
||||||
* @param {string} type Generation type
|
|
||||||
* @param {object} data Generation data
|
|
||||||
*/
|
|
||||||
async function registerFunctionTools(type, data) {
|
|
||||||
let toolChoice = 'auto';
|
|
||||||
const tools = [];
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @type {registerFunctionTool}
|
|
||||||
*/
|
|
||||||
const registerFunctionTool = (name, description, parameters, required) => {
|
|
||||||
tools.push({
|
|
||||||
type: 'function',
|
|
||||||
function: {
|
|
||||||
name,
|
|
||||||
description,
|
|
||||||
parameters,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
if (required) {
|
|
||||||
toolChoice = 'required';
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @type {FunctionToolRegister}
|
|
||||||
*/
|
|
||||||
const args = {
|
|
||||||
type,
|
|
||||||
data,
|
|
||||||
registerFunctionTool,
|
|
||||||
};
|
|
||||||
|
|
||||||
await eventSource.emit(event_types.LLM_FUNCTION_TOOL_REGISTER, args);
|
|
||||||
|
|
||||||
if (tools.length) {
|
|
||||||
console.log('Registered function tools:', tools);
|
|
||||||
|
|
||||||
data['tools'] = tools;
|
|
||||||
data['tool_choice'] = toolChoice;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function checkFunctionToolCalls(data) {
|
|
||||||
const oaiCompat = [
|
|
||||||
chat_completion_sources.OPENAI,
|
|
||||||
chat_completion_sources.CUSTOM,
|
|
||||||
chat_completion_sources.MISTRALAI,
|
|
||||||
chat_completion_sources.OPENROUTER,
|
|
||||||
chat_completion_sources.GROQ,
|
|
||||||
];
|
|
||||||
if (oaiCompat.includes(oai_settings.chat_completion_source)) {
|
|
||||||
if (!Array.isArray(data?.choices)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find a choice with 0-index
|
|
||||||
const choice = data.choices.find(choice => choice.index === 0);
|
|
||||||
|
|
||||||
if (!choice) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const toolCalls = choice.message.tool_calls;
|
|
||||||
|
|
||||||
if (!Array.isArray(toolCalls)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const toolCall of toolCalls) {
|
|
||||||
if (typeof toolCall.function !== 'object') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @type {FunctionToolCall} */
|
|
||||||
const args = toolCall.function;
|
|
||||||
console.log('Function tool call:', toolCall);
|
|
||||||
await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ([chat_completion_sources.CLAUDE].includes(oai_settings.chat_completion_source)) {
|
|
||||||
if (!Array.isArray(data?.content)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const content of data.content) {
|
|
||||||
if (content.type === 'tool_use') {
|
|
||||||
/** @type {FunctionToolCall} */
|
|
||||||
const args = { name: content.name, arguments: JSON.stringify(content.input) };
|
|
||||||
await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ([chat_completion_sources.COHERE].includes(oai_settings.chat_completion_source)) {
|
|
||||||
if (!Array.isArray(data?.tool_calls)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const toolCall of data.tool_calls) {
|
|
||||||
/** @type {FunctionToolCall} */
|
|
||||||
const args = { name: toolCall.name, arguments: JSON.stringify(toolCall.parameters) };
|
|
||||||
console.log('Function tool call:', toolCall);
|
|
||||||
await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export function isFunctionCallingSupported() {
|
|
||||||
if (main_api !== 'openai') {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!oai_settings.function_calling) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const supportedSources = [
|
|
||||||
chat_completion_sources.OPENAI,
|
|
||||||
chat_completion_sources.COHERE,
|
|
||||||
chat_completion_sources.CUSTOM,
|
|
||||||
chat_completion_sources.MISTRALAI,
|
|
||||||
chat_completion_sources.CLAUDE,
|
|
||||||
chat_completion_sources.OPENROUTER,
|
|
||||||
chat_completion_sources.GROQ,
|
|
||||||
];
|
|
||||||
return supportedSources.includes(oai_settings.chat_completion_source);
|
|
||||||
}
|
|
||||||
|
|
||||||
function getStreamingReply(data) {
|
function getStreamingReply(data) {
|
||||||
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
|
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
|
||||||
return data?.delta?.text || '';
|
return data?.delta?.text || '';
|
||||||
|
@ -4019,7 +3886,7 @@ async function onModelChange() {
|
||||||
$('#openai_max_context').attr('max', max_32k);
|
$('#openai_max_context').attr('max', max_32k);
|
||||||
} else if (value === 'text-bison-001') {
|
} else if (value === 'text-bison-001') {
|
||||||
$('#openai_max_context').attr('max', max_8k);
|
$('#openai_max_context').attr('max', max_8k);
|
||||||
// The ultra endpoints are possibly dead:
|
// The ultra endpoints are possibly dead:
|
||||||
} else if (value.includes('gemini-1.0-ultra') || value === 'gemini-ultra') {
|
} else if (value.includes('gemini-1.0-ultra') || value === 'gemini-ultra') {
|
||||||
$('#openai_max_context').attr('max', max_32k);
|
$('#openai_max_context').attr('max', max_32k);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -0,0 +1,381 @@
|
||||||
|
import { chat, main_api } from '../script.js';
|
||||||
|
import { chat_completion_sources, oai_settings } from './openai.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @typedef {object} ToolInvocation
|
||||||
|
* @property {string} id - A unique identifier for the tool invocation.
|
||||||
|
* @property {string} name - The name of the tool.
|
||||||
|
* @property {string} parameters - The parameters for the tool invocation.
|
||||||
|
* @property {string} result - The result of the tool invocation.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A class that represents a tool definition.
|
||||||
|
*/
|
||||||
|
class ToolDefinition {
|
||||||
|
/**
|
||||||
|
* A unique name for the tool.
|
||||||
|
* @type {string}
|
||||||
|
*/
|
||||||
|
#name;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A description of what the tool does.
|
||||||
|
* @type {string}
|
||||||
|
*/
|
||||||
|
#description;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A JSON schema for the parameters that the tool accepts.
|
||||||
|
* @type {object}
|
||||||
|
*/
|
||||||
|
#parameters;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A function that will be called when the tool is executed.
|
||||||
|
* @type {function}
|
||||||
|
*/
|
||||||
|
#action;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new ToolDefinition.
|
||||||
|
* @param {string} name A unique name for the tool.
|
||||||
|
* @param {string} description A description of what the tool does.
|
||||||
|
* @param {object} parameters A JSON schema for the parameters that the tool accepts.
|
||||||
|
* @param {function} action A function that will be called when the tool is executed.
|
||||||
|
*/
|
||||||
|
constructor(name, description, parameters, action) {
|
||||||
|
this.#name = name;
|
||||||
|
this.#description = description;
|
||||||
|
this.#parameters = parameters;
|
||||||
|
this.#action = action;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the ToolDefinition to an OpenAI API representation
|
||||||
|
* @returns {object} OpenAI API representation of the tool.
|
||||||
|
*/
|
||||||
|
toFunctionOpenAI() {
|
||||||
|
return {
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: this.#name,
|
||||||
|
description: this.#description,
|
||||||
|
parameters: this.#parameters,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Invokes the tool with the given parameters.
|
||||||
|
* @param {object} parameters The parameters to pass to the tool.
|
||||||
|
* @returns {Promise<any>} The result of the tool's action function.
|
||||||
|
*/
|
||||||
|
async invoke(parameters) {
|
||||||
|
return await this.#action(parameters);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A class that manages the registration and invocation of tools.
|
||||||
|
*/
|
||||||
|
export class ToolManager {
|
||||||
|
/**
|
||||||
|
* A map of tool names to tool definitions.
|
||||||
|
* @type {Map<string, ToolDefinition>}
|
||||||
|
*/
|
||||||
|
static #tools = new Map();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an Array of all tools that have been registered.
|
||||||
|
* @type {ToolDefinition[]}
|
||||||
|
*/
|
||||||
|
static get tools() {
|
||||||
|
return Array.from(this.#tools.values());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Registers a new tool with the tool registry.
|
||||||
|
* @param {string} name The name of the tool.
|
||||||
|
* @param {string} description A description of what the tool does.
|
||||||
|
* @param {object} parameters A JSON schema for the parameters that the tool accepts.
|
||||||
|
* @param {function} action A function that will be called when the tool is executed.
|
||||||
|
*/
|
||||||
|
static registerFunctionTool(name, description, parameters, action) {
|
||||||
|
if (this.#tools.has(name)) {
|
||||||
|
console.warn(`A tool with the name "${name}" has already been registered. The definition will be overwritten.`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const definition = new ToolDefinition(name, description, parameters, action);
|
||||||
|
this.#tools.set(name, definition);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes a tool from the tool registry.
|
||||||
|
* @param {string} name The name of the tool to unregister.
|
||||||
|
*/
|
||||||
|
static unregisterFunctionTool(name) {
|
||||||
|
if (!this.#tools.has(name)) {
|
||||||
|
console.warn(`No tool with the name "${name}" has been registered.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.#tools.delete(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Invokes a tool by name. Returns the result of the tool's action function.
|
||||||
|
* @param {string} name The name of the tool to invoke.
|
||||||
|
* @param {object} parameters Function parameters. For example, if the tool requires a "name" parameter, you would pass {name: "value"}.
|
||||||
|
* @returns {Promise<string|null>} The result of the tool's action function. If an error occurs, null is returned. Non-string results are JSON-stringified.
|
||||||
|
*/
|
||||||
|
static async invokeFunctionTool(name, parameters) {
|
||||||
|
try {
|
||||||
|
if (!this.#tools.has(name)) {
|
||||||
|
throw new Error(`No tool with the name "${name}" has been registered.`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const invokeParameters = typeof parameters === 'string' ? JSON.parse(parameters) : parameters;
|
||||||
|
const tool = this.#tools.get(name);
|
||||||
|
const result = await tool.invoke(invokeParameters);
|
||||||
|
return typeof result === 'string' ? result : JSON.stringify(result);
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`An error occurred while invoking the tool "${name}":`, error);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register function tools for the next chat completion request.
|
||||||
|
* @param {object} data Generation data
|
||||||
|
*/
|
||||||
|
static async registerFunctionToolsOpenAI(data) {
|
||||||
|
const tools = [];
|
||||||
|
|
||||||
|
for (const tool of ToolManager.tools) {
|
||||||
|
tools.push(tool.toFunctionOpenAI());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tools.length) {
|
||||||
|
console.log('Registered function tools:', tools);
|
||||||
|
|
||||||
|
data['tools'] = tools;
|
||||||
|
data['tool_choice'] = 'auto';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility function to parse tool calls from a parsed response.
|
||||||
|
* @param {any[]} toolCalls The tool calls to update.
|
||||||
|
* @param {any} parsed The parsed response from the OpenAI API.
|
||||||
|
* @returns {void}
|
||||||
|
*/
|
||||||
|
static parseToolCalls(toolCalls, parsed) {
|
||||||
|
if (!Array.isArray(parsed?.choices)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (const choice of parsed.choices) {
|
||||||
|
const choiceIndex = (typeof choice.index === 'number') ? choice.index : null;
|
||||||
|
const choiceDelta = choice.delta;
|
||||||
|
|
||||||
|
if (choiceIndex === null || !choiceDelta) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const toolCallDeltas = choiceDelta?.tool_calls;
|
||||||
|
|
||||||
|
if (!Array.isArray(toolCallDeltas)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!Array.isArray(toolCalls[choiceIndex])) {
|
||||||
|
toolCalls[choiceIndex] = [];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const toolCallDelta of toolCallDeltas) {
|
||||||
|
const toolCallIndex = (typeof toolCallDelta?.index === 'number') ? toolCallDelta.index : null;
|
||||||
|
|
||||||
|
if (toolCallIndex === null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (toolCalls[choiceIndex][toolCallIndex] === undefined) {
|
||||||
|
toolCalls[choiceIndex][toolCallIndex] = {};
|
||||||
|
}
|
||||||
|
|
||||||
|
const targetToolCall = toolCalls[choiceIndex][toolCallIndex];
|
||||||
|
|
||||||
|
ToolManager.#applyToolCallDelta(targetToolCall, toolCallDelta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static #applyToolCallDelta(target, delta) {
|
||||||
|
for (const key in delta) {
|
||||||
|
if (!delta.hasOwnProperty(key)) continue;
|
||||||
|
|
||||||
|
const deltaValue = delta[key];
|
||||||
|
const targetValue = target[key];
|
||||||
|
|
||||||
|
if (deltaValue === null || deltaValue === undefined) {
|
||||||
|
target[key] = deltaValue;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof deltaValue === 'string') {
|
||||||
|
if (typeof targetValue === 'string') {
|
||||||
|
// Concatenate strings
|
||||||
|
target[key] = targetValue + deltaValue;
|
||||||
|
} else {
|
||||||
|
target[key] = deltaValue;
|
||||||
|
}
|
||||||
|
} else if (typeof deltaValue === 'object' && !Array.isArray(deltaValue)) {
|
||||||
|
if (typeof targetValue !== 'object' || targetValue === null || Array.isArray(targetValue)) {
|
||||||
|
target[key] = {};
|
||||||
|
}
|
||||||
|
// Recursively apply deltas to nested objects
|
||||||
|
ToolManager.#applyToolCallDelta(target[key], deltaValue);
|
||||||
|
} else {
|
||||||
|
// Assign other types directly
|
||||||
|
target[key] = deltaValue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static isFunctionCallingSupported() {
|
||||||
|
if (main_api !== 'openai') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!oai_settings.function_calling) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const supportedSources = [
|
||||||
|
chat_completion_sources.OPENAI,
|
||||||
|
//chat_completion_sources.COHERE,
|
||||||
|
chat_completion_sources.CUSTOM,
|
||||||
|
chat_completion_sources.MISTRALAI,
|
||||||
|
//chat_completion_sources.CLAUDE,
|
||||||
|
chat_completion_sources.OPENROUTER,
|
||||||
|
chat_completion_sources.GROQ,
|
||||||
|
];
|
||||||
|
return supportedSources.includes(oai_settings.chat_completion_source);
|
||||||
|
}
|
||||||
|
|
||||||
|
static #getToolCallsFromData(data) {
|
||||||
|
// Parsed tool calls from streaming data
|
||||||
|
if (Array.isArray(data) && data.length > 0) {
|
||||||
|
return data[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parsed tool calls from non-streaming data
|
||||||
|
if (!Array.isArray(data?.choices)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find a choice with 0-index
|
||||||
|
const choice = data.choices.find(choice => choice.index === 0);
|
||||||
|
|
||||||
|
if (!choice) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return choice.message.tool_calls;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check for function tool calls in the response data and invoke them.
|
||||||
|
* @param {any} data Reply data
|
||||||
|
* @returns {Promise<ToolInvocation[]>} Successful tool invocations
|
||||||
|
*/
|
||||||
|
static async checkFunctionToolCalls(data) {
|
||||||
|
if (!ToolManager.isFunctionCallingSupported()) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @type {ToolInvocation[]} */
|
||||||
|
const invocations = [];
|
||||||
|
const toolCalls = ToolManager.#getToolCallsFromData(data);
|
||||||
|
const oaiCompat = [
|
||||||
|
chat_completion_sources.OPENAI,
|
||||||
|
chat_completion_sources.CUSTOM,
|
||||||
|
chat_completion_sources.MISTRALAI,
|
||||||
|
chat_completion_sources.OPENROUTER,
|
||||||
|
chat_completion_sources.GROQ,
|
||||||
|
];
|
||||||
|
|
||||||
|
if (oaiCompat.includes(oai_settings.chat_completion_source)) {
|
||||||
|
if (!Array.isArray(toolCalls)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const toolCall of toolCalls) {
|
||||||
|
if (typeof toolCall.function !== 'object') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log('Function tool call:', toolCall);
|
||||||
|
const id = toolCall.id;
|
||||||
|
const parameters = toolCall.function.arguments;
|
||||||
|
const name = toolCall.function.name;
|
||||||
|
|
||||||
|
toastr.info('Invoking function tool: ' + name);
|
||||||
|
const result = await ToolManager.invokeFunctionTool(name, parameters);
|
||||||
|
toastr.info('Function tool result: ' + result);
|
||||||
|
|
||||||
|
// Save a successful invocation
|
||||||
|
if (result) {
|
||||||
|
invocations.push({ id, name, result, parameters });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
if ([chat_completion_sources.CLAUDE].includes(oai_settings.chat_completion_source)) {
|
||||||
|
if (!Array.isArray(data?.content)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const content of data.content) {
|
||||||
|
if (content.type === 'tool_use') {
|
||||||
|
const args = { name: content.name, arguments: JSON.stringify(content.input) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
if ([chat_completion_sources.COHERE].includes(oai_settings.chat_completion_source)) {
|
||||||
|
if (!Array.isArray(data?.tool_calls)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const toolCall of data.tool_calls) {
|
||||||
|
const args = { name: toolCall.name, arguments: JSON.stringify(toolCall.parameters) };
|
||||||
|
console.log('Function tool call:', toolCall);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
return invocations;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Saves function tool invocations to the last user chat message extra metadata.
|
||||||
|
* @param {ToolInvocation[]} invocations
|
||||||
|
*/
|
||||||
|
static saveFunctionToolInvocations(invocations) {
|
||||||
|
for (let index = chat.length - 1; index >= 0; index--) {
|
||||||
|
const message = chat[index];
|
||||||
|
if (message.is_user) {
|
||||||
|
if (!message.extra || typeof message.extra !== 'object') {
|
||||||
|
message.extra = {};
|
||||||
|
}
|
||||||
|
message.extra.tool_invocations = invocations;
|
||||||
|
debugger;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -121,18 +121,20 @@ async function sendClaudeRequest(request, response) {
|
||||||
? [{ type: 'text', text: convertedPrompt.systemPrompt, cache_control: { type: 'ephemeral' } }]
|
? [{ type: 'text', text: convertedPrompt.systemPrompt, cache_control: { type: 'ephemeral' } }]
|
||||||
: convertedPrompt.systemPrompt;
|
: convertedPrompt.systemPrompt;
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
||||||
// Claude doesn't do prefills on function calls, and doesn't allow empty messages
|
// Claude doesn't do prefills on function calls, and doesn't allow empty messages
|
||||||
if (convertedPrompt.messages.length && convertedPrompt.messages[convertedPrompt.messages.length - 1].role === 'assistant') {
|
if (convertedPrompt.messages.length && convertedPrompt.messages[convertedPrompt.messages.length - 1].role === 'assistant') {
|
||||||
convertedPrompt.messages.push({ role: 'user', content: '.' });
|
convertedPrompt.messages.push({ role: 'user', content: '.' });
|
||||||
}
|
}
|
||||||
additionalHeaders['anthropic-beta'] = 'tools-2024-05-16';
|
additionalHeaders['anthropic-beta'] = 'tools-2024-05-16';
|
||||||
requestBody.tool_choice = { type: request.body.tool_choice === 'required' ? 'any' : 'auto' };
|
requestBody.tool_choice = { type: request.body.tool_choice };
|
||||||
requestBody.tools = request.body.tools
|
requestBody.tools = request.body.tools
|
||||||
.filter(tool => tool.type === 'function')
|
.filter(tool => tool.type === 'function')
|
||||||
.map(tool => tool.function)
|
.map(tool => tool.function)
|
||||||
.map(fn => ({ name: fn.name, description: fn.description, input_schema: fn.parameters }));
|
.map(fn => ({ name: fn.name, description: fn.description, input_schema: fn.parameters }));
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
if (enableSystemPromptCache) {
|
if (enableSystemPromptCache) {
|
||||||
additionalHeaders['anthropic-beta'] = 'prompt-caching-2024-07-31';
|
additionalHeaders['anthropic-beta'] = 'prompt-caching-2024-07-31';
|
||||||
}
|
}
|
||||||
|
@ -479,7 +481,7 @@ async function sendMistralAIRequest(request, response) {
|
||||||
|
|
||||||
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
||||||
requestBody['tools'] = request.body.tools;
|
requestBody['tools'] = request.body.tools;
|
||||||
requestBody['tool_choice'] = request.body.tool_choice === 'required' ? 'any' : 'auto';
|
requestBody['tool_choice'] = request.body.tool_choice;
|
||||||
}
|
}
|
||||||
|
|
||||||
const config = {
|
const config = {
|
||||||
|
@ -549,11 +551,13 @@ async function sendCohereRequest(request, response) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
||||||
tools.push(...convertCohereTools(request.body.tools));
|
tools.push(...convertCohereTools(request.body.tools));
|
||||||
// Can't have both connectors and tools in the same request
|
// Can't have both connectors and tools in the same request
|
||||||
connectors.splice(0, connectors.length);
|
connectors.splice(0, connectors.length);
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
// https://docs.cohere.com/reference/chat
|
// https://docs.cohere.com/reference/chat
|
||||||
const requestBody = {
|
const requestBody = {
|
||||||
|
@ -910,18 +914,6 @@ router.post('/generate', jsonParser, function (request, response) {
|
||||||
apiKey = readSecret(request.user.directories, SECRET_KEYS.GROQ);
|
apiKey = readSecret(request.user.directories, SECRET_KEYS.GROQ);
|
||||||
headers = {};
|
headers = {};
|
||||||
bodyParams = {};
|
bodyParams = {};
|
||||||
|
|
||||||
// 'required' tool choice is not supported by Groq
|
|
||||||
if (request.body.tool_choice === 'required') {
|
|
||||||
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
|
||||||
request.body.tool_choice = request.body.tools.length > 1
|
|
||||||
? 'auto' :
|
|
||||||
{ type: 'function', function: { name: request.body.tools[0]?.function?.name } };
|
|
||||||
|
|
||||||
} else {
|
|
||||||
request.body.tool_choice = 'none';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.ZEROONEAI) {
|
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.ZEROONEAI) {
|
||||||
apiUrl = API_01AI;
|
apiUrl = API_01AI;
|
||||||
apiKey = readSecret(request.user.directories, SECRET_KEYS.ZEROONEAI);
|
apiKey = readSecret(request.user.directories, SECRET_KEYS.ZEROONEAI);
|
||||||
|
@ -958,7 +950,7 @@ router.post('/generate', jsonParser, function (request, response) {
|
||||||
controller.abort();
|
controller.abort();
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!isTextCompletion) {
|
if (!isTextCompletion && Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
||||||
bodyParams['tools'] = request.body.tools;
|
bodyParams['tools'] = request.body.tools;
|
||||||
bodyParams['tool_choice'] = request.body.tool_choice;
|
bodyParams['tool_choice'] = request.body.tool_choice;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue