Refactor and rework the PaLM request to work with the 'content' format, and add an endpoint for Google's tokenizer

This commit is contained in:
based
2023-12-14 15:49:50 +10:00
parent be396991de
commit e26159c00d
8 changed files with 108 additions and 46 deletions

View File

Before

Width:  |  Height:  |  Size: 4.0 KiB

After

Width:  |  Height:  |  Size: 4.0 KiB

View File

@ -1457,7 +1457,7 @@ async function sendOpenAIRequest(type, messages, signal) {
replaceItemizedPromptText(messageId, messages); replaceItemizedPromptText(messageId, messages);
} }
if (isAI21 || isGoogle) { if (isAI21) {
const joinedMsgs = messages.reduce((acc, obj) => { const joinedMsgs = messages.reduce((acc, obj) => {
const prefix = prefixMap[obj.role]; const prefix = prefixMap[obj.role];
return acc + (prefix ? (selected_group ? '\n' : prefix + ' ') : '') + obj.content + '\n'; return acc + (prefix ? (selected_group ? '\n' : prefix + ' ') : '') + obj.content + '\n';

View File

@ -376,6 +376,10 @@ export function getTokenizerModel() {
} }
} }
if(oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) {
return oai_settings.google_model;
}
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) { if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
return claudeTokenizer; return claudeTokenizer;
} }
@ -389,6 +393,15 @@ export function getTokenizerModel() {
*/ */
export function countTokensOpenAI(messages, full = false) { export function countTokensOpenAI(messages, full = false) {
const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer; const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
const shouldTokenizeGoogle = oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE;
let tokenizerEndpoint = '';
if(shouldTokenizeAI21) {
tokenizerEndpoint = '/api/tokenizers/ai21/count';
} else if (shouldTokenizeGoogle) {
tokenizerEndpoint = `/api/tokenizers/google/count?model=${getTokenizerModel()}`;
} else {
tokenizerEndpoint = `/api/tokenizers/openai/count?model=${getTokenizerModel()}`;
}
const cacheObject = getTokenCacheObject(); const cacheObject = getTokenCacheObject();
if (!Array.isArray(messages)) { if (!Array.isArray(messages)) {
@ -400,7 +413,7 @@ export function countTokensOpenAI(messages, full = false) {
for (const message of messages) { for (const message of messages) {
const model = getTokenizerModel(); const model = getTokenizerModel();
if (model === 'claude' || shouldTokenizeAI21) { if (model === 'claude' || shouldTokenizeAI21 || shouldTokenizeGoogle) {
full = true; full = true;
} }
@ -416,7 +429,7 @@ export function countTokensOpenAI(messages, full = false) {
jQuery.ajax({ jQuery.ajax({
async: false, async: false,
type: 'POST', // type: 'POST', //
url: shouldTokenizeAI21 ? '/api/tokenizers/ai21/count' : `/api/tokenizers/openai/count?model=${model}`, url: tokenizerEndpoint,
data: JSON.stringify([message]), data: JSON.stringify([message]),
dataType: 'json', dataType: 'json',
contentType: 'application/json', contentType: 'application/json',

View File

@ -59,7 +59,7 @@ const {
} = require('./src/util'); } = require('./src/util');
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails'); const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers'); const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
const { convertClaudePrompt } = require('./src/chat-completion'); const { convertClaudePrompt, convertGooglePrompt } = require('./src/chat-completion');
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0. // Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870 // https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
@ -131,7 +131,7 @@ const API_OPENAI = 'https://api.openai.com/v1';
const API_CLAUDE = 'https://api.anthropic.com/v1'; const API_CLAUDE = 'https://api.anthropic.com/v1';
const SETTINGS_FILE = './public/settings.json'; const SETTINGS_FILE = './public/settings.json';
const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants'); const { DIRECTORIES, UPLOADS_PATH, MAKERSUITE_SAFETY, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants');
// CORS Settings // // CORS Settings //
const CORS = cors({ const CORS = cors({
@ -994,29 +994,30 @@ async function sendClaudeRequest(request, response) {
* @param {express.Request} request * @param {express.Request} request
* @param {express.Response} response * @param {express.Response} response
*/ */
async function sendPalmRequest(request, response) { async function sendMakerSuiteRequest(request, response) {
const api_key_makersuite = readSecret(SECRET_KEYS.PALM); const api_key_makersuite = readSecret(SECRET_KEYS.MAKERSUITE);
if (!api_key_makersuite) { if (!api_key_makersuite) {
console.log('Palm API key is missing.'); console.log('MakerSuite API key is missing.');
return response.status(400).send({ error: true }); return response.status(400).send({ error: true });
} }
const body = { const generationConfig = {
prompt: {
text: request.body.messages,
},
stopSequences: request.body.stop, stopSequences: request.body.stop,
safetySettings: PALM_SAFETY, candidateCount: 1,
maxOutputTokens: request.body.max_tokens,
temperature: request.body.temperature, temperature: request.body.temperature,
topP: request.body.top_p, topP: request.body.top_p,
topK: request.body.top_k || undefined, topK: request.body.top_k || undefined,
maxOutputTokens: request.body.max_tokens,
candidate_count: 1,
}; };
console.log('Palm request:', body); const body = {
contents: convertGooglePrompt(request.body.messages),
safetySettings: MAKERSUITE_SAFETY,
generationConfig: generationConfig,
};
const google_model = request.body.model;
try { try {
const controller = new AbortController(); const controller = new AbortController();
request.socket.removeAllListeners('close'); request.socket.removeAllListeners('close');
@ -1024,7 +1025,7 @@ async function sendPalmRequest(request, response) {
controller.abort(); controller.abort();
}); });
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta2/models/text-bison-001:generateText?key=${api_key_makersuite}`, { const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${google_model}:generateContent?key=${api_key_makersuite}`, {
body: JSON.stringify(body), body: JSON.stringify(body),
method: 'POST', method: 'POST',
headers: { headers: {
@ -1035,32 +1036,37 @@ async function sendPalmRequest(request, response) {
}); });
if (!generateResponse.ok) { if (!generateResponse.ok) {
console.log(`Palm API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`); console.log(`MakerSuite API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
return response.status(generateResponse.status).send({ error: true }); return response.status(generateResponse.status).send({ error: true });
} }
const generateResponseJson = await generateResponse.json(); const generateResponseJson = await generateResponse.json();
const responseText = generateResponseJson?.candidates[0]?.output;
if (!responseText) { const candidates = generateResponseJson?.candidates;
console.log('Palm API returned no response', generateResponseJson); if (!candidates || candidates.length === 0) {
let message = `Palm API returned no response: ${JSON.stringify(generateResponseJson)}`; let message = 'MakerSuite API returned no candidate';
console.log(message, generateResponseJson);
// Check for filters if (generateResponseJson?.promptFeedback?.blockReason) {
if (generateResponseJson?.filters[0]?.message) { message += `\nPrompt was blocked due to : ${generateResponseJson.promptFeedback.blockReason}`;
message = `Palm filter triggered: ${generateResponseJson.filters[0].message}`;
} }
return response.send({ error: { message } }); return response.send({ error: { message } });
} }
console.log('Palm response:', responseText); const responseContent = candidates[0].content;
const responseText = responseContent.parts[0].text;
if (!responseText) {
let message = 'MakerSuite Candidate text empty';
console.log(message, generateResponseJson);
return response.send({ error: { message } });
}
console.log('MakerSuite response:', responseText);
// Wrap it back to OAI format // Wrap it back to OAI format
const reply = { choices: [{ 'message': { 'content': responseText } }] }; const reply = { choices: [{ 'message': { 'content': responseText } }] };
return response.send(reply); return response.send(reply);
} catch (error) { } catch (error) {
console.log('Error communicating with Palm API: ', error); console.log('Error communicating with MakerSuite API: ', error);
if (!response.headersSent) { if (!response.headersSent) {
return response.status(500).send({ error: true }); return response.status(500).send({ error: true });
} }
@ -1074,7 +1080,7 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_op
case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response_generate_openai); case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response_generate_openai);
case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response_generate_openai); case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response_generate_openai);
case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response_generate_openai); case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response_generate_openai);
case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest(request, response_generate_openai); case CHAT_COMPLETION_SOURCES.MAKERSUITE: return sendMakerSuiteRequest(request, response_generate_openai);
} }
let api_url; let api_url;

View File

@ -72,6 +72,36 @@ function convertClaudePrompt(messages, addHumanPrefix, addAssistantPostfix, with
return requestPrompt; return requestPrompt;
} }
/**
 * Converts an OpenAI-style message array into the `contents` shape expected by
 * Google's generateContent API. Consecutive messages from the same side are
 * merged into a single entry (joined with blank lines), since the API expects
 * alternating turns; any non-assistant role ('user', 'system', …) maps to 'user'.
 * @param {{role: string, content: string}[]} messages - Chat messages to convert.
 * @returns {{parts: {text: string}[], role: string}[]} Google-format contents array.
 */
function convertGooglePrompt(messages) {
    const contents = [];
    let pendingRole = '';
    let pendingText = '';

    // Emit the turn accumulated so far as one content entry.
    const flush = () => {
        contents.push({
            parts: [{ text: pendingText.trim() }],
            role: pendingRole,
        });
    };

    for (const message of messages) {
        const mappedRole = message.role === 'assistant' ? 'model' : 'user';

        if (mappedRole === pendingRole) {
            // Same side as the previous message: merge into the current turn.
            pendingText += '\n\n' + message.content;
        } else {
            // Side changed: close out the previous turn (unless it never
            // accumulated any text) and start a new one.
            if (pendingText !== '') {
                flush();
            }
            pendingText = message.content;
            pendingRole = mappedRole;
        }
    }

    // Flush the trailing turn; an empty input produces an empty array.
    if (messages.length > 0) {
        flush();
    }

    return contents;
}
module.exports = { module.exports = {
convertClaudePrompt, convertClaudePrompt,
convertGooglePrompt,
}; };

View File

@ -105,29 +105,21 @@ const UNSAFE_EXTENSIONS = [
'.ws', '.ws',
]; ];
const PALM_SAFETY = [ const MAKERSUITE_SAFETY = [
{ {
category: 'HARM_CATEGORY_DEROGATORY', category: 'HARM_CATEGORY_HARASSMENT',
threshold: 'BLOCK_NONE', threshold: 'BLOCK_NONE',
}, },
{ {
category: 'HARM_CATEGORY_TOXICITY', category: 'HARM_CATEGORY_HATE_SPEECH',
threshold: 'BLOCK_NONE', threshold: 'BLOCK_NONE',
}, },
{ {
category: 'HARM_CATEGORY_VIOLENCE', category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
threshold: 'BLOCK_NONE', threshold: 'BLOCK_NONE',
}, },
{ {
category: 'HARM_CATEGORY_SEXUAL', category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
threshold: 'BLOCK_NONE',
},
{
category: 'HARM_CATEGORY_MEDICAL',
threshold: 'BLOCK_NONE',
},
{
category: 'HARM_CATEGORY_DANGEROUS',
threshold: 'BLOCK_NONE', threshold: 'BLOCK_NONE',
}, },
]; ];
@ -139,7 +131,7 @@ const CHAT_COMPLETION_SOURCES = {
SCALE: 'scale', SCALE: 'scale',
OPENROUTER: 'openrouter', OPENROUTER: 'openrouter',
AI21: 'ai21', AI21: 'ai21',
PALM: 'palm', MAKERSUITE: 'makersuite',
}; };
const UPLOADS_PATH = './uploads'; const UPLOADS_PATH = './uploads';
@ -160,7 +152,7 @@ module.exports = {
DIRECTORIES, DIRECTORIES,
UNSAFE_EXTENSIONS, UNSAFE_EXTENSIONS,
UPLOADS_PATH, UPLOADS_PATH,
PALM_SAFETY, MAKERSUITE_SAFETY,
TEXTGEN_TYPES, TEXTGEN_TYPES,
CHAT_COMPLETION_SOURCES, CHAT_COMPLETION_SOURCES,
AVATAR_WIDTH, AVATAR_WIDTH,

View File

@ -387,6 +387,27 @@ router.post('/ai21/count', jsonParser, async function (req, res) {
} }
}); });
/**
 * Counts tokens for a message using Google's tokenizer.
 * Expects the request body to be an array of messages (only the first one's
 * `content` is counted, matching how the client posts one message at a time)
 * and the target model in the `model` query parameter.
 * Responds with `{ token_count: number }`; degrades to 0 on upstream failure.
 */
router.post('/google/count', jsonParser, async function (req, res) {
    if (!req.body) return res.sendStatus(400);

    // The v1beta Gemini models expose `countTokens`, which takes the same
    // `contents` shape as `generateContent`. The legacy `countTextTokens`
    // method with a `prompt.text` body belongs to the retired PaLM models
    // and is not supported by `gemini-*` models.
    const options = {
        method: 'POST',
        headers: {
            accept: 'application/json',
            'content-type': 'application/json',
        },
        body: JSON.stringify({ contents: [{ parts: [{ text: req.body[0].content }] }] }),
    };

    try {
        const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${req.query.model}:countTokens?key=${readSecret(SECRET_KEYS.MAKERSUITE)}`, options);
        const data = await response.json();
        console.log(data);
        // `countTokens` reports the count in `totalTokens`.
        return res.send({ 'token_count': data?.totalTokens || 0 });
    } catch (err) {
        console.error(err);
        // Best-effort: a failed upstream call yields a zero count instead of
        // surfacing an error to the tokenizer UI.
        return res.send({ 'token_count': 0 });
    }
});
router.post('/llama/encode', jsonParser, createSentencepieceEncodingHandler(spp_llama)); router.post('/llama/encode', jsonParser, createSentencepieceEncodingHandler(spp_llama));
router.post('/nerdstash/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd)); router.post('/nerdstash/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd));
router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2)); router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));

View File

@ -14,7 +14,7 @@ async function getPaLMVector(text) {
throw new Error('No PaLM key found'); throw new Error('No PaLM key found');
} }
const response = await fetch(`https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key=${key}`, { const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/embedding-gecko-001:embedText?key=${key}`, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',