From 6185974e17a77d7d5caacd3ff2853413283b5ac7 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Sat, 5 Oct 2024 18:04:08 +0300 Subject: [PATCH] Claude: Use multi-part system prompt, cache tools --- src/endpoints/backends/chat-completions.js | 16 +++++++++++----- src/prompt-converters.js | 6 +++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index 603f4e7e4..4836a9f8b 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -87,7 +87,7 @@ async function sendClaudeRequest(request, response) { const apiUrl = new URL(request.body.reverse_proxy || API_CLAUDE).toString(); const apiKey = request.body.reverse_proxy ? request.body.proxy_password : readSecret(request.user.directories, SECRET_KEYS.CLAUDE); const divider = '-'.repeat(process.stdout.columns); - const enableSystemPromptCache = getConfigValue('claude.enableSystemPromptCache', false); + const enableSystemPromptCache = getConfigValue('claude.enableSystemPromptCache', false) && request.body.model.startsWith('claude-3'); if (!apiKey) { console.log(color.red(`Claude API key is missing.\n${divider}`)); @@ -110,7 +110,7 @@ async function sendClaudeRequest(request, response) { } const requestBody = { - /** @type {any} */ system: '', + /** @type {any} */ system: [], messages: convertedPrompt.messages, model: request.body.model, max_tokens: request.body.max_tokens, @@ -121,9 +121,11 @@ async function sendClaudeRequest(request, response) { stream: request.body.stream, }; if (useSystemPrompt) { - requestBody.system = enableSystemPromptCache - ? [{ type: 'text', text: convertedPrompt.systemPrompt, cache_control: { type: 'ephemeral' } }] - : convertedPrompt.systemPrompt; + if (enableSystemPromptCache && Array.isArray(convertedPrompt.systemPrompt) && convertedPrompt.systemPrompt.length) { + convertedPrompt.systemPrompt[convertedPrompt.systemPrompt.length - 1]['cache_control'] = { type: 'ephemeral' }; + } + + requestBody.system = convertedPrompt.systemPrompt; } else { delete requestBody.system; } @@ -138,6 +140,10 @@ async function sendClaudeRequest(request, response) { .filter(tool => tool.type === 'function') .map(tool => tool.function) .map(fn => ({ name: fn.name, description: fn.description, input_schema: fn.parameters })); + + if (enableSystemPromptCache && requestBody.tools.length) { + requestBody.tools[requestBody.tools.length - 1]['cache_control'] = { type: 'ephemeral' }; + } } if (enableSystemPromptCache) { additionalHeaders['anthropic-beta'] = 'prompt-caching-2024-07-31'; diff --git a/src/prompt-converters.js b/src/prompt-converters.js index 57362b8e2..c27ff8f8a 100644 --- a/src/prompt-converters.js +++ b/src/prompt-converters.js @@ -95,7 +95,7 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill, * @param {string} userName User name */ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFix, charName = '', userName = '') { - let systemPrompt = ''; + let systemPrompt = []; if (useSysPrompt) { // Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array. let i; @@ -114,7 +114,7 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi messages[i].content = `${charName}: ${messages[i].content}`; } } - systemPrompt += `${messages[i].content}\n\n`; + systemPrompt.push({ type: 'text', text: messages[i].content }); } messages.splice(0, i); @@ -246,7 +246,7 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi } }); - return { messages: mergedMessages, systemPrompt: systemPrompt.trim() }; + return { messages: mergedMessages, systemPrompt: systemPrompt }; } /**