diff --git a/public/img/palm.svg b/public/img/makersuite.svg similarity index 100% rename from public/img/palm.svg rename to public/img/makersuite.svg diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 644f4f195..015dd66a6 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -1457,7 +1457,7 @@ async function sendOpenAIRequest(type, messages, signal) { replaceItemizedPromptText(messageId, messages); } - if (isAI21 || isGoogle) { + if (isAI21) { const joinedMsgs = messages.reduce((acc, obj) => { const prefix = prefixMap[obj.role]; return acc + (prefix ? (selected_group ? '\n' : prefix + ' ') : '') + obj.content + '\n'; diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js index decd0f919..7600a0909 100644 --- a/public/scripts/tokenizers.js +++ b/public/scripts/tokenizers.js @@ -376,6 +376,10 @@ export function getTokenizerModel() { } } + if(oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) { + return oai_settings.google_model; + } + if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) { return claudeTokenizer; } @@ -389,6 +393,15 @@ export function getTokenizerModel() { */ export function countTokensOpenAI(messages, full = false) { const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer; + const shouldTokenizeGoogle = oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE; + let tokenizerEndpoint = ''; + if(shouldTokenizeAI21) { + tokenizerEndpoint = '/api/tokenizers/ai21/count'; + } else if (shouldTokenizeGoogle) { + tokenizerEndpoint = `/api/tokenizers/google/count?model=${getTokenizerModel()}`; + } else { + tokenizerEndpoint = `/api/tokenizers/openai/count?model=${getTokenizerModel()}`; + } const cacheObject = getTokenCacheObject(); if (!Array.isArray(messages)) { @@ -400,7 +413,7 @@ export function countTokensOpenAI(messages, full = false) { for (const message of 
messages) { const model = getTokenizerModel(); - if (model === 'claude' || shouldTokenizeAI21) { + if (model === 'claude' || shouldTokenizeAI21 || shouldTokenizeGoogle) { full = true; } @@ -416,7 +429,7 @@ export function countTokensOpenAI(messages, full = false) { jQuery.ajax({ async: false, type: 'POST', // - url: shouldTokenizeAI21 ? '/api/tokenizers/ai21/count' : `/api/tokenizers/openai/count?model=${model}`, + url: tokenizerEndpoint, data: JSON.stringify([message]), dataType: 'json', contentType: 'application/json', diff --git a/server.js b/server.js index 6da29c278..ac56f54fb 100644 --- a/server.js +++ b/server.js @@ -59,7 +59,7 @@ const { } = require('./src/util'); const { ensureThumbnailCache } = require('./src/endpoints/thumbnails'); const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers'); -const { convertClaudePrompt } = require('./src/chat-completion'); +const { convertClaudePrompt, convertGooglePrompt } = require('./src/chat-completion'); // Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0. 
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870 @@ -131,7 +131,7 @@ const API_OPENAI = 'https://api.openai.com/v1'; const API_CLAUDE = 'https://api.anthropic.com/v1'; const SETTINGS_FILE = './public/settings.json'; -const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants'); +const { DIRECTORIES, UPLOADS_PATH, MAKERSUITE_SAFETY, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants'); // CORS Settings // const CORS = cors({ @@ -994,29 +994,30 @@ async function sendClaudeRequest(request, response) { * @param {express.Request} request * @param {express.Response} response */ -async function sendPalmRequest(request, response) { - const api_key_makersuite = readSecret(SECRET_KEYS.PALM); +async function sendMakerSuiteRequest(request, response) { + const api_key_makersuite = readSecret(SECRET_KEYS.MAKERSUITE); if (!api_key_makersuite) { - console.log('Palm API key is missing.'); + console.log('MakerSuite API key is missing.'); return response.status(400).send({ error: true }); } - const body = { - prompt: { - text: request.body.messages, - }, + const generationConfig = { stopSequences: request.body.stop, - safetySettings: PALM_SAFETY, + candidateCount: 1, + maxOutputTokens: request.body.max_tokens, temperature: request.body.temperature, topP: request.body.top_p, topK: request.body.top_k || undefined, - maxOutputTokens: request.body.max_tokens, - candidate_count: 1, }; - console.log('Palm request:', body); + const body = { + contents: convertGooglePrompt(request.body.messages), + safetySettings: MAKERSUITE_SAFETY, + generationConfig: generationConfig, + }; + const google_model = request.body.model; try { const controller = new AbortController(); request.socket.removeAllListeners('close'); @@ -1024,7 +1025,7 @@ async function sendPalmRequest(request, response) { controller.abort(); }); - const generateResponse = await 
fetch(`https://generativelanguage.googleapis.com/v1beta2/models/text-bison-001:generateText?key=${api_key_makersuite}`, { + const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${google_model}:generateContent?key=${api_key_makersuite}`, { body: JSON.stringify(body), method: 'POST', headers: { @@ -1035,32 +1036,37 @@ }); if (!generateResponse.ok) { - console.log(`Palm API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`); + console.log(`MakerSuite API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`); return response.status(generateResponse.status).send({ error: true }); } const generateResponseJson = await generateResponse.json(); - const responseText = generateResponseJson?.candidates[0]?.output; - if (!responseText) { - console.log('Palm API returned no response', generateResponseJson); - let message = `Palm API returned no response: ${JSON.stringify(generateResponseJson)}`; - - // Check for filters - if (generateResponseJson?.filters[0]?.message) { - message = `Palm filter triggered: ${generateResponseJson.filters[0].message}`; + const candidates = generateResponseJson?.candidates; + if (!candidates || candidates.length === 0) { + let message = 'MakerSuite API returned no candidate'; + console.log(message, generateResponseJson); + if (generateResponseJson?.promptFeedback?.blockReason) { + message += `\nPrompt was blocked due to: ${generateResponseJson.promptFeedback.blockReason}`; } - return response.send({ error: { message } }); } - console.log('Palm response:', responseText); + const responseContent = candidates[0].content; + const responseText = responseContent.parts[0].text; + if (!responseText) { + let message = 'MakerSuite Candidate text empty'; + console.log(message, generateResponseJson); + return response.send({ error: { message } }); + } + +
console.log('MakerSuite response:', responseText); // Wrap it back to OAI format const reply = { choices: [{ 'message': { 'content': responseText } }] }; return response.send(reply); } catch (error) { - console.log('Error communicating with Palm API: ', error); + console.log('Error communicating with MakerSuite API: ', error); if (!response.headersSent) { return response.status(500).send({ error: true }); } @@ -1074,7 +1080,7 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_op case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response_generate_openai); case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response_generate_openai); case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response_generate_openai); - case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest(request, response_generate_openai); + case CHAT_COMPLETION_SOURCES.MAKERSUITE: return sendMakerSuiteRequest(request, response_generate_openai); } let api_url; diff --git a/src/chat-completion.js b/src/chat-completion.js index 4fc21a550..d1f97f8a3 100644 --- a/src/chat-completion.js +++ b/src/chat-completion.js @@ -72,6 +72,36 @@ function convertClaudePrompt(messages, addHumanPrefix, addAssistantPostfix, with return requestPrompt; } +function convertGooglePrompt(messages) { + const contents = []; + let lastRole = ''; + let currentText = ''; + + messages.forEach((message, index) => { + const role = message.role === 'assistant' ? 
'model' : 'user'; + if (lastRole === role) { + currentText += '\n\n' + message.content; + } else { + if (currentText !== '') { + contents.push({ + parts: [{ text: currentText.trim() }], + role: lastRole, + }); + } + currentText = message.content; + lastRole = role; + } + if (index === messages.length - 1) { + contents.push({ + parts: [{ text: currentText.trim() }], + role: lastRole, + }); + } + }); + return contents; +} + module.exports = { convertClaudePrompt, + convertGooglePrompt, }; diff --git a/src/constants.js b/src/constants.js index 32ea6fad5..7151ada24 100644 --- a/src/constants.js +++ b/src/constants.js @@ -105,29 +105,21 @@ const UNSAFE_EXTENSIONS = [ '.ws', ]; -const PALM_SAFETY = [ +const MAKERSUITE_SAFETY = [ { - category: 'HARM_CATEGORY_DEROGATORY', + category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_NONE', }, { - category: 'HARM_CATEGORY_TOXICITY', + category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_NONE', }, { - category: 'HARM_CATEGORY_VIOLENCE', + category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', threshold: 'BLOCK_NONE', }, { - category: 'HARM_CATEGORY_SEXUAL', - threshold: 'BLOCK_NONE', - }, - { - category: 'HARM_CATEGORY_MEDICAL', - threshold: 'BLOCK_NONE', - }, - { - category: 'HARM_CATEGORY_DANGEROUS', + category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE', }, ]; @@ -139,7 +131,7 @@ const CHAT_COMPLETION_SOURCES = { SCALE: 'scale', OPENROUTER: 'openrouter', AI21: 'ai21', - PALM: 'palm', + MAKERSUITE: 'makersuite', }; const UPLOADS_PATH = './uploads'; @@ -160,7 +152,7 @@ module.exports = { DIRECTORIES, UNSAFE_EXTENSIONS, UPLOADS_PATH, - PALM_SAFETY, + MAKERSUITE_SAFETY, TEXTGEN_TYPES, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, diff --git a/src/endpoints/tokenizers.js b/src/endpoints/tokenizers.js index a81779d97..096b0f093 100644 --- a/src/endpoints/tokenizers.js +++ b/src/endpoints/tokenizers.js @@ -387,6 +387,27 @@ router.post('/ai21/count', jsonParser, async function (req, res) { } }); +router.post('/google/count', 
jsonParser, async function (req, res) { + if (!req.body) return res.sendStatus(400); + const options = { + method: 'POST', + headers: { + accept: 'application/json', + 'content-type': 'application/json', + }, + body: JSON.stringify({ prompt: { text: req.body[0].content } }), + }; + try { + const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${req.query.model}:countTextTokens?key=${readSecret(SECRET_KEYS.MAKERSUITE)}`, options); + const data = await response.json(); + console.log(data); + return res.send({ 'token_count': data?.tokenCount || 0 }); + } catch (err) { + console.error(err); + return res.send({ 'token_count': 0 }); + } +}); + router.post('/llama/encode', jsonParser, createSentencepieceEncodingHandler(spp_llama)); router.post('/nerdstash/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd)); router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2)); diff --git a/src/palm-vectors.js b/src/palm-vectors.js index 788b474cd..b4e6a68bd 100644 --- a/src/palm-vectors.js +++ b/src/palm-vectors.js @@ -14,7 +14,7 @@ async function getPaLMVector(text) { throw new Error('No PaLM key found'); } - const response = await fetch(`https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key=${key}`, { + const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/embedding-gecko-001:embedText?key=${key}`, { method: 'POST', headers: { 'Content-Type': 'application/json',