Claude: new prompt converter + non-streaming tools

This commit is contained in:
Cohee 2024-10-04 03:41:25 +03:00
parent 559f1b81f7
commit c3c10a629e
3 changed files with 141 additions and 99 deletions

View File

@ -339,10 +339,9 @@ export class ToolManager {
const supportedSources = [ const supportedSources = [
chat_completion_sources.OPENAI, chat_completion_sources.OPENAI,
//chat_completion_sources.COHERE,
chat_completion_sources.CUSTOM, chat_completion_sources.CUSTOM,
chat_completion_sources.MISTRALAI, chat_completion_sources.MISTRALAI,
//chat_completion_sources.CLAUDE, chat_completion_sources.CLAUDE,
chat_completion_sources.OPENROUTER, chat_completion_sources.OPENROUTER,
chat_completion_sources.GROQ, chat_completion_sources.GROQ,
]; ];
@ -372,18 +371,29 @@ export class ToolManager {
} }
// Parsed tool calls from non-streaming data // Parsed tool calls from non-streaming data
if (!Array.isArray(data?.choices)) { if (Array.isArray(data?.choices)) {
return;
}
// Find a choice with 0-index // Find a choice with 0-index
const choice = data.choices.find(choice => choice.index === 0); const choice = data.choices.find(choice => choice.index === 0);
if (!choice) { if (choice) {
return; return choice.message.tool_calls;
}
} }
return choice.message.tool_calls; if (Array.isArray(data?.content)) {
// Claude tool calls to OpenAI tool calls
const content = data.content.filter(c => c.type === 'tool_use').map(c => {
return {
id: c.id,
function: {
name: c.name,
arguments: c.input,
},
};
});
return content;
}
} }
/** /**
@ -407,7 +417,6 @@ export class ToolManager {
chat_completion_sources.GROQ, chat_completion_sources.GROQ,
]; ];
if (oaiCompatibleSources.includes(oai_settings.chat_completion_source)) {
if (!Array.isArray(toolCalls)) { if (!Array.isArray(toolCalls)) {
return result; return result;
} }
@ -445,21 +454,6 @@ export class ToolManager {
}; };
result.invocations.push(invocation); result.invocations.push(invocation);
} }
}
/*
if ([chat_completion_sources.CLAUDE].includes(oai_settings.chat_completion_source)) {
if (!Array.isArray(data?.content)) {
return;
}
for (const content of data.content) {
if (content.type === 'tool_use') {
const args = { name: content.name, arguments: JSON.stringify(content.input) };
}
}
}
*/
return result; return result;
} }
@ -491,6 +485,9 @@ export class ToolManager {
* @param {ToolInvocation[]} invocations Successful tool invocations * @param {ToolInvocation[]} invocations Successful tool invocations
*/ */
static saveFunctionToolInvocations(invocations) { static saveFunctionToolInvocations(invocations) {
if (!Array.isArray(invocations) || invocations.length === 0) {
return;
}
const message = { const message = {
name: systemUserName, name: systemUserName,
force_avatar: system_avatar, force_avatar: system_avatar,

View File

@ -124,7 +124,6 @@ async function sendClaudeRequest(request, response) {
} else { } else {
delete requestBody.system; delete requestBody.system;
} }
/*
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
// Claude doesn't do prefills on function calls, and doesn't allow empty messages // Claude doesn't do prefills on function calls, and doesn't allow empty messages
if (convertedPrompt.messages.length && convertedPrompt.messages[convertedPrompt.messages.length - 1].role === 'assistant') { if (convertedPrompt.messages.length && convertedPrompt.messages[convertedPrompt.messages.length - 1].role === 'assistant') {
@ -137,7 +136,6 @@ async function sendClaudeRequest(request, response) {
.map(tool => tool.function) .map(tool => tool.function)
.map(fn => ({ name: fn.name, description: fn.description, input_schema: fn.parameters })); .map(fn => ({ name: fn.name, description: fn.description, input_schema: fn.parameters }));
} }
*/
if (enableSystemPromptCache) { if (enableSystemPromptCache) {
additionalHeaders['anthropic-beta'] = 'prompt-caching-2024-07-31'; additionalHeaders['anthropic-beta'] = 'prompt-caching-2024-07-31';
} }

View File

@ -118,8 +118,27 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
}); });
} }
} }
// Now replace all further messages that have the role 'system' with the role 'user'. (or all if we're not using one) // Now replace all further messages that have the role 'system' with the role 'user'. (or all if we're not using one)
messages.forEach((message) => { messages.forEach((message) => {
if (message.role === 'assistant' && message.tool_calls) {
message.content = message.tool_calls.map((tc) => ({
type: 'tool_use',
id: tc.id,
name: tc.function.name,
input: tc.function.arguments,
}));
}
if (message.role === 'tool') {
message.role = 'user';
message.content = [{
type: 'tool_result',
tool_use_id: message.tool_call_id,
content: message.content,
}];
}
if (message.role === 'system') { if (message.role === 'system') {
if (userName && message.name === 'example_user') { if (userName && message.name === 'example_user') {
message.content = `${userName}: ${message.content}`; message.content = `${userName}: ${message.content}`;
@ -128,13 +147,80 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
message.content = `${charName}: ${message.content}`; message.content = `${charName}: ${message.content}`;
} }
message.role = 'user'; message.role = 'user';
// Delete name here so it doesn't get added later
delete message.name;
} }
// Convert everything to an array of it would be easier to work with
if (typeof message.content === 'string') {
// Take care of name properties since claude messages don't support them
if (message.name) {
message.content = `${message.name}: ${message.content}`;
}
message.content = [{ type: 'text', text: message.content }];
} else if (Array.isArray(message.content)) {
message.content = message.content.map((content) => {
if (content.type === 'image_url') {
const imageEntry = content?.image_url;
const imageData = imageEntry?.url;
const mimeType = imageData?.split(';')?.[0].split(':')?.[1];
const base64Data = imageData?.split(',')?.[1];
return {
type: 'image',
source: {
type: 'base64',
media_type: mimeType,
data: base64Data,
},
};
}
if (content.type === 'text') {
if (message.name) {
content.text = `${message.name}: ${content.text}`;
}
return content;
}
return content;
}); });
}
// Remove offending properties
delete message.name;
delete message.tool_calls;
delete message.tool_call_id;
});
// Images in assistant messages should be moved to the next user message
for (let i = 0; i < messages.length; i++) {
if (messages[i].role === 'assistant' && messages[i].content.some(c => c.type === 'image')) {
// Find the next user message
let j = i + 1;
while (j < messages.length && messages[j].role !== 'user') {
j++;
}
// Move the images
if (j >= messages.length) {
// If there is no user message after the assistant message, add a new one
messages.splice(i + 1, 0, { role: 'user', content: [] });
}
messages[j].content.push(...messages[i].content.filter(c => c.type === 'image'));
messages[i].content = messages[i].content.filter(c => c.type !== 'image');
}
}
// Shouldn't be conditional anymore, messages api expects the last role to be user unless we're explicitly prefilling // Shouldn't be conditional anymore, messages api expects the last role to be user unless we're explicitly prefilling
if (prefillString) { if (prefillString) {
messages.push({ messages.push({
role: 'assistant', role: 'assistant',
// Dangling whitespace are not allowed for prefilling
content: prefillString.trimEnd(), content: prefillString.trimEnd(),
}); });
} }
@ -143,50 +229,11 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, humanMsgFi
// Also handle multi-modality, holy slop. // Also handle multi-modality, holy slop.
let mergedMessages = []; let mergedMessages = [];
messages.forEach((message) => { messages.forEach((message) => {
const imageEntry = message.content?.[1]?.image_url;
const imageData = imageEntry?.url;
const mimeType = imageData?.split(';')?.[0].split(':')?.[1];
const base64Data = imageData?.split(',')?.[1];
// Take care of name properties since claude messages don't support them
if (message.name) {
if (Array.isArray(message.content)) {
message.content[0].text = `${message.name}: ${message.content[0].text}`;
} else {
message.content = `${message.name}: ${message.content}`;
}
delete message.name;
}
if (mergedMessages.length > 0 && mergedMessages[mergedMessages.length - 1].role === message.role) { if (mergedMessages.length > 0 && mergedMessages[mergedMessages.length - 1].role === message.role) {
if (Array.isArray(message.content)) { mergedMessages[mergedMessages.length - 1].content.push(...message.content);
if (Array.isArray(mergedMessages[mergedMessages.length - 1].content)) {
mergedMessages[mergedMessages.length - 1].content[0].text += '\n\n' + message.content[0].text;
} else {
mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content[0].text;
}
} else {
if (Array.isArray(mergedMessages[mergedMessages.length - 1].content)) {
mergedMessages[mergedMessages.length - 1].content[0].text += '\n\n' + message.content;
} else {
mergedMessages[mergedMessages.length - 1].content += '\n\n' + message.content;
}
}
} else { } else {
mergedMessages.push(message); mergedMessages.push(message);
} }
if (imageData) {
mergedMessages[mergedMessages.length - 1].content = [
{ type: 'text', text: mergedMessages[mergedMessages.length - 1].content[0]?.text || mergedMessages[mergedMessages.length - 1].content },
{
type: 'image', source: {
type: 'base64',
media_type: mimeType,
data: base64Data,
},
},
];
}
}); });
return { messages: mergedMessages, systemPrompt: systemPrompt.trim() }; return { messages: mergedMessages, systemPrompt: systemPrompt.trim() };