mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-03-02 19:07:40 +01:00
Function calling for Cohere
This commit is contained in:
parent
dc8530049f
commit
fa6fc45e6f
@ -1739,7 +1739,7 @@
|
|||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="range-block" data-source="openai,custom">
|
<div class="range-block" data-source="openai,cohere,custom">
|
||||||
<label for="openai_function_calling" class="checkbox_label flexWrap widthFreeExpand">
|
<label for="openai_function_calling" class="checkbox_label flexWrap widthFreeExpand">
|
||||||
<input id="openai_function_calling" type="checkbox" />
|
<input id="openai_function_calling" type="checkbox" />
|
||||||
<span data-i18n="Enable function calling">Enable function calling</span>
|
<span data-i18n="Enable function calling">Enable function calling</span>
|
||||||
|
@ -1063,6 +1063,7 @@ function onFunctionToolRegister(args) {
|
|||||||
emotion: {
|
emotion: {
|
||||||
type: 'string',
|
type: 'string',
|
||||||
enum: emotions,
|
enum: emotions,
|
||||||
|
description: `One of the following: ${JSON.stringify(emotions)}`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
required: [
|
required: [
|
||||||
|
@ -1968,7 +1968,8 @@ async function registerFunctionTools(type, data) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function checkFunctionToolCalls(data) {
|
async function checkFunctionToolCalls(data) {
|
||||||
if (!Array.isArray(data.choices)) {
|
if ([chat_completion_sources.OPENAI, chat_completion_sources.CUSTOM].includes(oai_settings.chat_completion_source)) {
|
||||||
|
if (!Array.isArray(data?.choices)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1996,6 +1997,21 @@ async function checkFunctionToolCalls(data) {
|
|||||||
await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args);
|
await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args);
|
||||||
data.allowEmptyResponse = true;
|
data.allowEmptyResponse = true;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ([chat_completion_sources.COHERE].includes(oai_settings.chat_completion_source)) {
|
||||||
|
if (!Array.isArray(data?.tool_calls)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const toolCall of data.tool_calls) {
|
||||||
|
/** @type {FunctionToolCall} */
|
||||||
|
const args = { name: toolCall.name, arguments: JSON.stringify(toolCall.parameters) };
|
||||||
|
console.log('Function tool call:', toolCall);
|
||||||
|
await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args);
|
||||||
|
data.allowEmptyResponse = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function isFunctionCallingSupported() {
|
export function isFunctionCallingSupported() {
|
||||||
@ -2009,6 +2025,7 @@ export function isFunctionCallingSupported() {
|
|||||||
|
|
||||||
const supportedSources = [
|
const supportedSources = [
|
||||||
chat_completion_sources.OPENAI,
|
chat_completion_sources.OPENAI,
|
||||||
|
chat_completion_sources.COHERE,
|
||||||
chat_completion_sources.CUSTOM,
|
chat_completion_sources.CUSTOM,
|
||||||
];
|
];
|
||||||
return supportedSources.includes(oai_settings.chat_completion_source);
|
return supportedSources.includes(oai_settings.chat_completion_source);
|
||||||
@ -3964,7 +3981,7 @@ async function onModelChange() {
|
|||||||
else if (['command-r', 'command-r-plus'].includes(oai_settings.cohere_model)) {
|
else if (['command-r', 'command-r-plus'].includes(oai_settings.cohere_model)) {
|
||||||
$('#openai_max_context').attr('max', max_128k);
|
$('#openai_max_context').attr('max', max_128k);
|
||||||
}
|
}
|
||||||
else if(['c4ai-aya-23'].includes(oai_settings.cohere_model)) {
|
else if (['c4ai-aya-23'].includes(oai_settings.cohere_model)) {
|
||||||
$('#openai_max_context').attr('max', max_8k);
|
$('#openai_max_context').attr('max', max_8k);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
@ -4555,7 +4572,8 @@ function runProxyCallback(_, value) {
|
|||||||
return foundName;
|
return foundName;
|
||||||
}
|
}
|
||||||
|
|
||||||
SlashCommandParser.addCommandObject(SlashCommand.fromProps({ name: 'proxy',
|
SlashCommandParser.addCommandObject(SlashCommand.fromProps({
|
||||||
|
name: 'proxy',
|
||||||
callback: runProxyCallback,
|
callback: runProxyCallback,
|
||||||
returns: 'current proxy',
|
returns: 'current proxy',
|
||||||
namedArgumentList: [],
|
namedArgumentList: [],
|
||||||
|
@ -5,7 +5,7 @@ const Readable = require('stream').Readable;
|
|||||||
const { jsonParser } = require('../../express-common');
|
const { jsonParser } = require('../../express-common');
|
||||||
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
|
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
|
||||||
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
|
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
|
||||||
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages } = require('../../prompt-converters');
|
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools } = require('../../prompt-converters');
|
||||||
|
|
||||||
const { readSecret, SECRET_KEYS } = require('../secrets');
|
const { readSecret, SECRET_KEYS } = require('../secrets');
|
||||||
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
|
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
|
||||||
@ -544,6 +544,7 @@ async function sendCohereRequest(request, response) {
|
|||||||
try {
|
try {
|
||||||
const convertedHistory = convertCohereMessages(request.body.messages, request.body.char_name, request.body.user_name);
|
const convertedHistory = convertCohereMessages(request.body.messages, request.body.char_name, request.body.user_name);
|
||||||
const connectors = [];
|
const connectors = [];
|
||||||
|
const tools = [];
|
||||||
|
|
||||||
if (request.body.websearch) {
|
if (request.body.websearch) {
|
||||||
connectors.push({
|
connectors.push({
|
||||||
@ -551,6 +552,12 @@ async function sendCohereRequest(request, response) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
||||||
|
tools.push(...convertCohereTools(request.body.tools));
|
||||||
|
// Can't have both connectors and tools in the same request
|
||||||
|
connectors.splice(0, connectors.length);
|
||||||
|
}
|
||||||
|
|
||||||
// https://docs.cohere.com/reference/chat
|
// https://docs.cohere.com/reference/chat
|
||||||
const requestBody = {
|
const requestBody = {
|
||||||
stream: Boolean(request.body.stream),
|
stream: Boolean(request.body.stream),
|
||||||
@ -569,8 +576,7 @@ async function sendCohereRequest(request, response) {
|
|||||||
prompt_truncation: 'AUTO_PRESERVE_ORDER',
|
prompt_truncation: 'AUTO_PRESERVE_ORDER',
|
||||||
connectors: connectors,
|
connectors: connectors,
|
||||||
documents: [],
|
documents: [],
|
||||||
tools: [],
|
tools: tools,
|
||||||
tool_results: [],
|
|
||||||
search_queries_only: false,
|
search_queries_only: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -451,6 +451,76 @@ function convertTextCompletionPrompt(messages) {
|
|||||||
return messageStrings.join('\n') + '\nassistant:';
|
return messageStrings.join('\n') + '\nassistant:';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert OpenAI Chat Completion tools to the format used by Cohere.
|
||||||
|
* @param {object[]} tools OpenAI Chat Completion tool definitions
|
||||||
|
*/
|
||||||
|
function convertCohereTools(tools) {
|
||||||
|
if (!Array.isArray(tools) || tools.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const jsonSchemaToPythonTypes = {
|
||||||
|
'string': 'str',
|
||||||
|
'number': 'float',
|
||||||
|
'integer': 'int',
|
||||||
|
'boolean': 'bool',
|
||||||
|
'array': 'list',
|
||||||
|
'object': 'dict',
|
||||||
|
};
|
||||||
|
|
||||||
|
const cohereTools = [];
|
||||||
|
|
||||||
|
for (const tool of tools) {
|
||||||
|
if (tool?.type !== 'function') {
|
||||||
|
console.log(`Unsupported tool type: ${tool.type}`);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const name = tool?.function?.name;
|
||||||
|
const description = tool?.function?.description;
|
||||||
|
const properties = tool?.function?.parameters?.properties;
|
||||||
|
const required = tool?.function?.parameters?.required;
|
||||||
|
const parameters = {};
|
||||||
|
|
||||||
|
if (!name) {
|
||||||
|
console.log('Tool name is missing');
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!description) {
|
||||||
|
console.log('Tool description is missing');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!properties || typeof properties !== 'object') {
|
||||||
|
console.log(`No properties found for tool: ${tool?.function?.name}`);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const property in properties) {
|
||||||
|
const parameterDefinition = properties[property];
|
||||||
|
const description = parameterDefinition.description || (parameterDefinition.enum ? JSON.stringify(parameterDefinition.enum) : '');
|
||||||
|
const type = jsonSchemaToPythonTypes[parameterDefinition.type] || 'str';
|
||||||
|
const isRequired = Array.isArray(required) && required.includes(property);
|
||||||
|
parameters[property] = {
|
||||||
|
description: description,
|
||||||
|
type: type,
|
||||||
|
required: isRequired,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const cohereTool = {
|
||||||
|
name: tool.function.name,
|
||||||
|
description: tool.function.description,
|
||||||
|
parameter_definitions: parameters,
|
||||||
|
};
|
||||||
|
|
||||||
|
cohereTools.push(cohereTool);
|
||||||
|
}
|
||||||
|
|
||||||
|
return cohereTools;
|
||||||
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
convertClaudePrompt,
|
convertClaudePrompt,
|
||||||
convertClaudeMessages,
|
convertClaudeMessages,
|
||||||
@ -458,4 +528,5 @@ module.exports = {
|
|||||||
convertTextCompletionPrompt,
|
convertTextCompletionPrompt,
|
||||||
convertCohereMessages,
|
convertCohereMessages,
|
||||||
convertMistralMessages,
|
convertMistralMessages,
|
||||||
|
convertCohereTools,
|
||||||
};
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user