mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Refactor and rework the PaLM request to work with the 'content' format, and add an endpoint for Google's tokenizer
This commit is contained in:
Before Width: | Height: | Size: 4.0 KiB After Width: | Height: | Size: 4.0 KiB |
@ -1457,7 +1457,7 @@ async function sendOpenAIRequest(type, messages, signal) {
|
|||||||
replaceItemizedPromptText(messageId, messages);
|
replaceItemizedPromptText(messageId, messages);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isAI21 || isGoogle) {
|
if (isAI21) {
|
||||||
const joinedMsgs = messages.reduce((acc, obj) => {
|
const joinedMsgs = messages.reduce((acc, obj) => {
|
||||||
const prefix = prefixMap[obj.role];
|
const prefix = prefixMap[obj.role];
|
||||||
return acc + (prefix ? (selected_group ? '\n' : prefix + ' ') : '') + obj.content + '\n';
|
return acc + (prefix ? (selected_group ? '\n' : prefix + ' ') : '') + obj.content + '\n';
|
||||||
|
@ -376,6 +376,10 @@ export function getTokenizerModel() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) {
|
||||||
|
return oai_settings.google_model;
|
||||||
|
}
|
||||||
|
|
||||||
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
|
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
|
||||||
return claudeTokenizer;
|
return claudeTokenizer;
|
||||||
}
|
}
|
||||||
@ -389,6 +393,15 @@ export function getTokenizerModel() {
|
|||||||
*/
|
*/
|
||||||
export function countTokensOpenAI(messages, full = false) {
|
export function countTokensOpenAI(messages, full = false) {
|
||||||
const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
|
const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
|
||||||
|
const shouldTokenizeGoogle = oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE;
|
||||||
|
let tokenizerEndpoint = '';
|
||||||
|
if(shouldTokenizeAI21) {
|
||||||
|
tokenizerEndpoint = '/api/tokenizers/ai21/count';
|
||||||
|
} else if (shouldTokenizeGoogle) {
|
||||||
|
tokenizerEndpoint = `/api/tokenizers/google/count?model=${getTokenizerModel()}`;
|
||||||
|
} else {
|
||||||
|
tokenizerEndpoint = `/api/tokenizers/openai/count?model=${getTokenizerModel()}`;
|
||||||
|
}
|
||||||
const cacheObject = getTokenCacheObject();
|
const cacheObject = getTokenCacheObject();
|
||||||
|
|
||||||
if (!Array.isArray(messages)) {
|
if (!Array.isArray(messages)) {
|
||||||
@ -400,7 +413,7 @@ export function countTokensOpenAI(messages, full = false) {
|
|||||||
for (const message of messages) {
|
for (const message of messages) {
|
||||||
const model = getTokenizerModel();
|
const model = getTokenizerModel();
|
||||||
|
|
||||||
if (model === 'claude' || shouldTokenizeAI21) {
|
if (model === 'claude' || shouldTokenizeAI21 || shouldTokenizeGoogle) {
|
||||||
full = true;
|
full = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -416,7 +429,7 @@ export function countTokensOpenAI(messages, full = false) {
|
|||||||
jQuery.ajax({
|
jQuery.ajax({
|
||||||
async: false,
|
async: false,
|
||||||
type: 'POST', //
|
type: 'POST', //
|
||||||
url: shouldTokenizeAI21 ? '/api/tokenizers/ai21/count' : `/api/tokenizers/openai/count?model=${model}`,
|
url: tokenizerEndpoint,
|
||||||
data: JSON.stringify([message]),
|
data: JSON.stringify([message]),
|
||||||
dataType: 'json',
|
dataType: 'json',
|
||||||
contentType: 'application/json',
|
contentType: 'application/json',
|
||||||
|
60
server.js
60
server.js
@ -59,7 +59,7 @@ const {
|
|||||||
} = require('./src/util');
|
} = require('./src/util');
|
||||||
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
|
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
|
||||||
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
|
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
|
||||||
const { convertClaudePrompt } = require('./src/chat-completion');
|
const { convertClaudePrompt, convertGooglePrompt } = require('./src/chat-completion');
|
||||||
|
|
||||||
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
|
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
|
||||||
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
|
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
|
||||||
@ -131,7 +131,7 @@ const API_OPENAI = 'https://api.openai.com/v1';
|
|||||||
const API_CLAUDE = 'https://api.anthropic.com/v1';
|
const API_CLAUDE = 'https://api.anthropic.com/v1';
|
||||||
|
|
||||||
const SETTINGS_FILE = './public/settings.json';
|
const SETTINGS_FILE = './public/settings.json';
|
||||||
const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants');
|
const { DIRECTORIES, UPLOADS_PATH, MAKERSUITE_SAFETY, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants');
|
||||||
|
|
||||||
// CORS Settings //
|
// CORS Settings //
|
||||||
const CORS = cors({
|
const CORS = cors({
|
||||||
@ -994,29 +994,30 @@ async function sendClaudeRequest(request, response) {
|
|||||||
* @param {express.Request} request
|
* @param {express.Request} request
|
||||||
* @param {express.Response} response
|
* @param {express.Response} response
|
||||||
*/
|
*/
|
||||||
async function sendPalmRequest(request, response) {
|
async function sendMakerSuiteRequest(request, response) {
|
||||||
const api_key_makersuite = readSecret(SECRET_KEYS.PALM);
|
const api_key_makersuite = readSecret(SECRET_KEYS.MAKERSUITE);
|
||||||
|
|
||||||
if (!api_key_makersuite) {
|
if (!api_key_makersuite) {
|
||||||
console.log('Palm API key is missing.');
|
console.log('MakerSuite API key is missing.');
|
||||||
return response.status(400).send({ error: true });
|
return response.status(400).send({ error: true });
|
||||||
}
|
}
|
||||||
|
|
||||||
const body = {
|
const generationConfig = {
|
||||||
prompt: {
|
|
||||||
text: request.body.messages,
|
|
||||||
},
|
|
||||||
stopSequences: request.body.stop,
|
stopSequences: request.body.stop,
|
||||||
safetySettings: PALM_SAFETY,
|
candidateCount: 1,
|
||||||
|
maxOutputTokens: request.body.max_tokens,
|
||||||
temperature: request.body.temperature,
|
temperature: request.body.temperature,
|
||||||
topP: request.body.top_p,
|
topP: request.body.top_p,
|
||||||
topK: request.body.top_k || undefined,
|
topK: request.body.top_k || undefined,
|
||||||
maxOutputTokens: request.body.max_tokens,
|
|
||||||
candidate_count: 1,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
console.log('Palm request:', body);
|
const body = {
|
||||||
|
contents: convertGooglePrompt(request.body.messages),
|
||||||
|
safetySettings: MAKERSUITE_SAFETY,
|
||||||
|
generationConfig: generationConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
const google_model = request.body.model;
|
||||||
try {
|
try {
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
request.socket.removeAllListeners('close');
|
request.socket.removeAllListeners('close');
|
||||||
@ -1024,7 +1025,7 @@ async function sendPalmRequest(request, response) {
|
|||||||
controller.abort();
|
controller.abort();
|
||||||
});
|
});
|
||||||
|
|
||||||
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta2/models/text-bison-001:generateText?key=${api_key_makersuite}`, {
|
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${google_model}:generateContent?key=${api_key_makersuite}`, {
|
||||||
body: JSON.stringify(body),
|
body: JSON.stringify(body),
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
@ -1035,32 +1036,37 @@ async function sendPalmRequest(request, response) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (!generateResponse.ok) {
|
if (!generateResponse.ok) {
|
||||||
console.log(`Palm API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
|
console.log(`MakerSuite API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
|
||||||
return response.status(generateResponse.status).send({ error: true });
|
return response.status(generateResponse.status).send({ error: true });
|
||||||
}
|
}
|
||||||
|
|
||||||
const generateResponseJson = await generateResponse.json();
|
const generateResponseJson = await generateResponse.json();
|
||||||
const responseText = generateResponseJson?.candidates[0]?.output;
|
|
||||||
|
|
||||||
if (!responseText) {
|
const candidates = generateResponseJson?.candidates;
|
||||||
console.log('Palm API returned no response', generateResponseJson);
|
if (!candidates || candidates.length === 0) {
|
||||||
let message = `Palm API returned no response: ${JSON.stringify(generateResponseJson)}`;
|
let message = 'MakerSuite API returned no candidate';
|
||||||
|
console.log(message, generateResponseJson);
|
||||||
// Check for filters
|
if (generateResponseJson?.promptFeedback?.blockReason) {
|
||||||
if (generateResponseJson?.filters[0]?.message) {
|
message += `\nPrompt was blocked due to : ${generateResponseJson.promptFeedback.blockReason}`;
|
||||||
message = `Palm filter triggered: ${generateResponseJson.filters[0].message}`;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return response.send({ error: { message } });
|
return response.send({ error: { message } });
|
||||||
}
|
}
|
||||||
|
|
||||||
console.log('Palm response:', responseText);
|
const responseContent = candidates[0].content;
|
||||||
|
const responseText = responseContent.parts[0].text;
|
||||||
|
if (!responseText) {
|
||||||
|
let message = 'MakerSuite Candidate text empty';
|
||||||
|
console.log(message, generateResponseJson);
|
||||||
|
return response.send({ error: { message } });
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log('MakerSuite response:', responseText);
|
||||||
|
|
||||||
// Wrap it back to OAI format
|
// Wrap it back to OAI format
|
||||||
const reply = { choices: [{ 'message': { 'content': responseText } }] };
|
const reply = { choices: [{ 'message': { 'content': responseText } }] };
|
||||||
return response.send(reply);
|
return response.send(reply);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log('Error communicating with Palm API: ', error);
|
console.log('Error communicating with MakerSuite API: ', error);
|
||||||
if (!response.headersSent) {
|
if (!response.headersSent) {
|
||||||
return response.status(500).send({ error: true });
|
return response.status(500).send({ error: true });
|
||||||
}
|
}
|
||||||
@ -1074,7 +1080,7 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_op
|
|||||||
case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response_generate_openai);
|
case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response_generate_openai);
|
||||||
case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response_generate_openai);
|
case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response_generate_openai);
|
||||||
case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response_generate_openai);
|
case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response_generate_openai);
|
||||||
case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest(request, response_generate_openai);
|
case CHAT_COMPLETION_SOURCES.MAKERSUITE: return sendMakerSuiteRequest(request, response_generate_openai);
|
||||||
}
|
}
|
||||||
|
|
||||||
let api_url;
|
let api_url;
|
||||||
|
@ -72,6 +72,36 @@ function convertClaudePrompt(messages, addHumanPrefix, addAssistantPostfix, with
|
|||||||
return requestPrompt;
|
return requestPrompt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Converts an OpenAI-style message list into the `contents` array sent to
 * Google's `generateContent` endpoint.
 *
 * - The `assistant` role maps to `model`; every other role maps to `user`.
 * - Consecutive messages that map to the same role are merged into a single
 *   turn, joined by a blank line.
 * - Messages whose content is missing or whitespace-only are skipped before
 *   merging, so they can neither emit an empty text part nor leave two
 *   same-role turns next to each other.
 * - The merged text of each turn is trimmed.
 *
 * @param {{role?: string, content?: any}[]} messages Chat messages in OpenAI format.
 *   Non-string content is coerced with `String()`.
 * @returns {{parts: {text: string}[], role: string}[]} Turns with `role` set to
 *   `'user'` or `'model'`; an empty array when there is nothing to send.
 */
function convertGooglePrompt(messages) {
    const contents = [];
    if (!Array.isArray(messages)) {
        return contents;
    }

    for (const message of messages) {
        const role = message?.role === 'assistant' ? 'model' : 'user';
        const text = String(message?.content ?? '');

        // A blank message carries no prompt text; skipping it here (rather than
        // after merging) lets its neighbours merge when they share a role.
        if (text.trim() === '') {
            continue;
        }

        const previous = contents[contents.length - 1];
        if (previous && previous.role === role) {
            previous.parts[0].text += `\n\n${text}`;
        } else {
            contents.push({ parts: [{ text }], role });
        }
    }

    // Trim once per merged turn so inner spacing between joined messages is kept.
    for (const turn of contents) {
        turn.parts[0].text = turn.parts[0].text.trim();
    }

    return contents;
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
convertClaudePrompt,
|
convertClaudePrompt,
|
||||||
|
convertGooglePrompt,
|
||||||
};
|
};
|
||||||
|
@ -105,29 +105,21 @@ const UNSAFE_EXTENSIONS = [
|
|||||||
'.ws',
|
'.ws',
|
||||||
];
|
];
|
||||||
|
|
||||||
const PALM_SAFETY = [
|
const MAKERSUITE_SAFETY = [
|
||||||
{
|
{
|
||||||
category: 'HARM_CATEGORY_DEROGATORY',
|
category: 'HARM_CATEGORY_HARASSMENT',
|
||||||
threshold: 'BLOCK_NONE',
|
threshold: 'BLOCK_NONE',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
category: 'HARM_CATEGORY_TOXICITY',
|
category: 'HARM_CATEGORY_HATE_SPEECH',
|
||||||
threshold: 'BLOCK_NONE',
|
threshold: 'BLOCK_NONE',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
category: 'HARM_CATEGORY_VIOLENCE',
|
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
|
||||||
threshold: 'BLOCK_NONE',
|
threshold: 'BLOCK_NONE',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
category: 'HARM_CATEGORY_SEXUAL',
|
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
|
||||||
threshold: 'BLOCK_NONE',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
category: 'HARM_CATEGORY_MEDICAL',
|
|
||||||
threshold: 'BLOCK_NONE',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
category: 'HARM_CATEGORY_DANGEROUS',
|
|
||||||
threshold: 'BLOCK_NONE',
|
threshold: 'BLOCK_NONE',
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
@ -139,7 +131,7 @@ const CHAT_COMPLETION_SOURCES = {
|
|||||||
SCALE: 'scale',
|
SCALE: 'scale',
|
||||||
OPENROUTER: 'openrouter',
|
OPENROUTER: 'openrouter',
|
||||||
AI21: 'ai21',
|
AI21: 'ai21',
|
||||||
PALM: 'palm',
|
MAKERSUITE: 'makersuite',
|
||||||
};
|
};
|
||||||
|
|
||||||
const UPLOADS_PATH = './uploads';
|
const UPLOADS_PATH = './uploads';
|
||||||
@ -160,7 +152,7 @@ module.exports = {
|
|||||||
DIRECTORIES,
|
DIRECTORIES,
|
||||||
UNSAFE_EXTENSIONS,
|
UNSAFE_EXTENSIONS,
|
||||||
UPLOADS_PATH,
|
UPLOADS_PATH,
|
||||||
PALM_SAFETY,
|
MAKERSUITE_SAFETY,
|
||||||
TEXTGEN_TYPES,
|
TEXTGEN_TYPES,
|
||||||
CHAT_COMPLETION_SOURCES,
|
CHAT_COMPLETION_SOURCES,
|
||||||
AVATAR_WIDTH,
|
AVATAR_WIDTH,
|
||||||
|
@ -387,6 +387,27 @@ router.post('/ai21/count', jsonParser, async function (req, res) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
 * POST /google/count
 *
 * Counts tokens for a single message by forwarding its text to Google's
 * `countTextTokens` endpoint for the model named in the `model` query parameter.
 *
 * Request body: a JSON array whose first element has a `content` property.
 * Response: `{ token_count: number }`; `0` is returned when the upstream call
 * fails or reports no count. Responds with HTTP 400 when the body is not a
 * non-empty array.
 */
router.post('/google/count', jsonParser, async function (req, res) {
    // Validate before dereferencing: a throw here would happen outside the try
    // block and surface as an unhandled rejection instead of an HTTP response.
    const firstMessage = Array.isArray(req.body) ? req.body[0] : undefined;
    if (!firstMessage) return res.sendStatus(400);

    const options = {
        method: 'POST',
        headers: {
            accept: 'application/json',
            'content-type': 'application/json',
        },
        body: JSON.stringify({ prompt: { text: firstMessage.content } }),
    };

    try {
        // Both values end up inside the upstream URL, so escape them to keep a
        // crafted model name from altering the request path or query string.
        const model = encodeURIComponent(String(req.query.model ?? ''));
        const key = encodeURIComponent(readSecret(SECRET_KEYS.MAKERSUITE));
        const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${model}:countTextTokens?key=${key}`, options);

        if (!response.ok) {
            console.error(`Google tokenizer API returned error: ${response.status} ${response.statusText}`);
            return res.send({ 'token_count': 0 });
        }

        const data = await response.json();
        return res.send({ 'token_count': data?.tokenCount || 0 });
    } catch (err) {
        console.error(err);
        return res.send({ 'token_count': 0 });
    }
});
|
||||||
|
|
||||||
router.post('/llama/encode', jsonParser, createSentencepieceEncodingHandler(spp_llama));
|
router.post('/llama/encode', jsonParser, createSentencepieceEncodingHandler(spp_llama));
|
||||||
router.post('/nerdstash/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd));
|
router.post('/nerdstash/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd));
|
||||||
router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
|
router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
|
||||||
|
@ -14,7 +14,7 @@ async function getPaLMVector(text) {
|
|||||||
throw new Error('No PaLM key found');
|
throw new Error('No PaLM key found');
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await fetch(`https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key=${key}`, {
|
const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/embedding-gecko-001:embedText?key=${key}`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
Reference in New Issue
Block a user