Split Custom OAI prompt post-processing modes

This commit is contained in:
Cohee
2024-10-05 16:09:39 +03:00
parent 0637223bc2
commit 3b4a455ef8
5 changed files with 101 additions and 9 deletions

View File

@ -4,7 +4,7 @@ const fetch = require('node-fetch').default;
const { jsonParser } = require('../../express-common');
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertAI21Messages } = require('../../prompt-converters');
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertAI21Messages, mergeMessages } = require('../../prompt-converters');
const CohereStream = require('../../cohere-stream');
const { readSecret, SECRET_KEYS } = require('../secrets');
@ -31,8 +31,11 @@ const API_AI21 = 'https://api.ai21.com/studio/v1';
*/
function postProcessPrompt(messages, type, charName, userName) {
switch (type) {
case 'merge':
case 'claude':
return convertClaudeMessages(messages, '', false, '', charName, userName).messages;
return mergeMessages(messages, charName, userName, false);
case 'strict':
return mergeMessages(messages, charName, userName, true);
default:
return messages;
}
@ -902,7 +905,7 @@ router.post('/generate', jsonParser, function (request, response) {
apiKey = readSecret(request.user.directories, SECRET_KEYS.PERPLEXITY);
headers = {};
bodyParams = {};
request.body.messages = postProcessPrompt(request.body.messages, 'claude', request.body.char_name, request.body.user_name);
request.body.messages = postProcessPrompt(request.body.messages, 'strict', request.body.char_name, request.body.user_name);
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.GROQ) {
apiUrl = API_GROQ;
apiKey = readSecret(request.user.directories, SECRET_KEYS.GROQ);