refactor and rework palm request to work with the 'content' format and added an endpoint for googles tokenizer

This commit is contained in:
based
2023-12-14 15:49:50 +10:00
parent be396991de
commit e26159c00d
8 changed files with 108 additions and 46 deletions

View File

@@ -59,7 +59,7 @@ const {
} = require('./src/util');
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
const { convertClaudePrompt } = require('./src/chat-completion');
const { convertClaudePrompt, convertGooglePrompt } = require('./src/chat-completion');
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
@@ -131,7 +131,7 @@ const API_OPENAI = 'https://api.openai.com/v1';
const API_CLAUDE = 'https://api.anthropic.com/v1';
const SETTINGS_FILE = './public/settings.json';
const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants');
const { DIRECTORIES, UPLOADS_PATH, MAKERSUITE_SAFETY, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants');
// CORS Settings //
const CORS = cors({
@@ -994,29 +994,30 @@ async function sendClaudeRequest(request, response) {
* @param {express.Request} request
* @param {express.Response} response
*/
async function sendPalmRequest(request, response) {
const api_key_makersuite = readSecret(SECRET_KEYS.PALM);
async function sendMakerSuiteRequest(request, response) {
const api_key_makersuite = readSecret(SECRET_KEYS.MAKERSUITE);
if (!api_key_makersuite) {
console.log('Palm API key is missing.');
console.log('MakerSuite API key is missing.');
return response.status(400).send({ error: true });
}
const body = {
prompt: {
text: request.body.messages,
},
const generationConfig = {
stopSequences: request.body.stop,
safetySettings: PALM_SAFETY,
candidateCount: 1,
maxOutputTokens: request.body.max_tokens,
temperature: request.body.temperature,
topP: request.body.top_p,
topK: request.body.top_k || undefined,
maxOutputTokens: request.body.max_tokens,
candidate_count: 1,
};
console.log('Palm request:', body);
const body = {
contents: convertGooglePrompt(request.body.messages),
safetySettings: MAKERSUITE_SAFETY,
generationConfig: generationConfig,
};
const google_model = request.body.model;
try {
const controller = new AbortController();
request.socket.removeAllListeners('close');
@@ -1024,7 +1025,7 @@ async function sendPalmRequest(request, response) {
controller.abort();
});
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta2/models/text-bison-001:generateText?key=${api_key_makersuite}`, {
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${google_model}:generateContent?key=${api_key_makersuite}`, {
body: JSON.stringify(body),
method: 'POST',
headers: {
@@ -1035,32 +1036,37 @@ async function sendPalmRequest(request, response) {
});
if (!generateResponse.ok) {
console.log(`Palm API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
console.log(`MakerSuite API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
return response.status(generateResponse.status).send({ error: true });
}
const generateResponseJson = await generateResponse.json();
const responseText = generateResponseJson?.candidates[0]?.output;
if (!responseText) {
console.log('Palm API returned no response', generateResponseJson);
let message = `Palm API returned no response: ${JSON.stringify(generateResponseJson)}`;
// Check for filters
if (generateResponseJson?.filters[0]?.message) {
message = `Palm filter triggered: ${generateResponseJson.filters[0].message}`;
const candidates = generateResponseJson?.candidates;
if (!candidates || candidates.length === 0) {
let message = 'MakerSuite API returned no candidate';
console.log(message, generateResponseJson);
if (generateResponseJson?.promptFeedback?.blockReason) {
message += `\nPrompt was blocked due to : ${generateResponseJson.promptFeedback.blockReason}`;
}
return response.send({ error: { message } });
}
console.log('Palm response:', responseText);
const responseContent = candidates[0].content;
const responseText = responseContent.parts[0].text;
if (!responseText) {
let message = 'MakerSuite Candidate text empty';
console.log(message, generateResponseJson);
return response.send({ error: { message } });
}
console.log('MakerSuite response:', responseText);
// Wrap it back to OAI format
const reply = { choices: [{ 'message': { 'content': responseText } }] };
return response.send(reply);
} catch (error) {
console.log('Error communicating with Palm API: ', error);
console.log('Error communicating with MakerSuite API: ', error);
if (!response.headersSent) {
return response.status(500).send({ error: true });
}
@@ -1074,7 +1080,7 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_op
case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response_generate_openai);
case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response_generate_openai);
case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response_generate_openai);
case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest(request, response_generate_openai);
case CHAT_COMPLETION_SOURCES.MAKERSUITE: return sendMakerSuiteRequest(request, response_generate_openai);
}
let api_url;