mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-02-09 00:28:52 +01:00
Merge pull request #1466 from valadaptive/tokenizers-router
Use Express router for tokenizers endpoint
This commit is contained in:
commit
60083b2a35
@ -3579,7 +3579,7 @@ async function fetchJSON(url, args = {}) {
|
|||||||
app.use('/api/openai', require('./src/endpoints/openai').router);
|
app.use('/api/openai', require('./src/endpoints/openai').router);
|
||||||
|
|
||||||
// Tokenizers
|
// Tokenizers
|
||||||
require('./src/endpoints/tokenizers').registerEndpoints(app, jsonParser);
|
app.use('/api/tokenizers', require('./src/endpoints/tokenizers').router);
|
||||||
|
|
||||||
// Preset management
|
// Preset management
|
||||||
app.use('/api/presets', require('./src/endpoints/presets').router);
|
app.use('/api/presets', require('./src/endpoints/presets').router);
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
const fs = require('fs');
|
const fs = require('fs');
|
||||||
const path = require('path');
|
const path = require('path');
|
||||||
|
const express = require('express');
|
||||||
const { SentencePieceProcessor } = require('@agnai/sentencepiece-js');
|
const { SentencePieceProcessor } = require('@agnai/sentencepiece-js');
|
||||||
const tiktoken = require('@dqbd/tiktoken');
|
const tiktoken = require('@dqbd/tiktoken');
|
||||||
const { Tokenizer } = require('@agnai/web-tokenizers');
|
const { Tokenizer } = require('@agnai/web-tokenizers');
|
||||||
const { convertClaudePrompt } = require('../chat-completion');
|
const { convertClaudePrompt } = require('../chat-completion');
|
||||||
const { readSecret, SECRET_KEYS } = require('./secrets');
|
const { readSecret, SECRET_KEYS } = require('./secrets');
|
||||||
|
const { jsonParser } = require('../express-common');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
|
* @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
|
||||||
@ -359,183 +361,178 @@ async function loadTokenizers() {
|
|||||||
claude_tokenizer = await loadClaudeTokenizer('src/claude.json');
|
claude_tokenizer = await loadClaudeTokenizer('src/claude.json');
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
const router = express.Router();
|
||||||
* Registers the tokenization endpoints.
|
|
||||||
* @param {import('express').Express} app Express app
|
router.post('/ai21/count', jsonParser, async function (req, res) {
|
||||||
* @param {any} jsonParser JSON parser middleware
|
if (!req.body) return res.sendStatus(400);
|
||||||
*/
|
const options = {
|
||||||
function registerEndpoints(app, jsonParser) {
|
method: 'POST',
|
||||||
app.post('/api/tokenizers/ai21/count', jsonParser, async function (req, res) {
|
headers: {
|
||||||
|
accept: 'application/json',
|
||||||
|
'content-type': 'application/json',
|
||||||
|
Authorization: `Bearer ${readSecret(SECRET_KEYS.AI21)}`,
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ text: req.body[0].content }),
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch('https://api.ai21.com/studio/v1/tokenize', options);
|
||||||
|
const data = await response.json();
|
||||||
|
return res.send({ 'token_count': data?.tokens?.length || 0 });
|
||||||
|
} catch (err) {
|
||||||
|
console.error(err);
|
||||||
|
return res.send({ 'token_count': 0 });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
router.post('/llama/encode', jsonParser, createSentencepieceEncodingHandler(spp_llama));
|
||||||
|
router.post('/nerdstash/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd));
|
||||||
|
router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
|
||||||
|
router.post('/mistral/encode', jsonParser, createSentencepieceEncodingHandler(spp_mistral));
|
||||||
|
router.post('/yi/encode', jsonParser, createSentencepieceEncodingHandler(spp_yi));
|
||||||
|
router.post('/gpt2/encode', jsonParser, createTiktokenEncodingHandler('gpt2'));
|
||||||
|
router.post('/llama/decode', jsonParser, createSentencepieceDecodingHandler(spp_llama));
|
||||||
|
router.post('/nerdstash/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd));
|
||||||
|
router.post('/nerdstash_v2/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
|
||||||
|
router.post('/mistral/decode', jsonParser, createSentencepieceDecodingHandler(spp_mistral));
|
||||||
|
router.post('/yi/decode', jsonParser, createSentencepieceDecodingHandler(spp_yi));
|
||||||
|
router.post('/gpt2/decode', jsonParser, createTiktokenDecodingHandler('gpt2'));
|
||||||
|
|
||||||
|
router.post('/openai/encode', jsonParser, async function (req, res) {
|
||||||
|
try {
|
||||||
|
const queryModel = String(req.query.model || '');
|
||||||
|
|
||||||
|
if (queryModel.includes('llama')) {
|
||||||
|
const handler = createSentencepieceEncodingHandler(spp_llama);
|
||||||
|
return handler(req, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (queryModel.includes('mistral')) {
|
||||||
|
const handler = createSentencepieceEncodingHandler(spp_mistral);
|
||||||
|
return handler(req, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (queryModel.includes('yi')) {
|
||||||
|
const handler = createSentencepieceEncodingHandler(spp_yi);
|
||||||
|
return handler(req, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (queryModel.includes('claude')) {
|
||||||
|
const text = req.body.text || '';
|
||||||
|
const tokens = Object.values(claude_tokenizer.encode(text));
|
||||||
|
const chunks = await getWebTokenizersChunks(claude_tokenizer, tokens);
|
||||||
|
return res.send({ ids: tokens, count: tokens.length, chunks });
|
||||||
|
}
|
||||||
|
|
||||||
|
const model = getTokenizerModel(queryModel);
|
||||||
|
const handler = createTiktokenEncodingHandler(model);
|
||||||
|
return handler(req, res);
|
||||||
|
} catch (error) {
|
||||||
|
console.log(error);
|
||||||
|
return res.send({ ids: [], count: 0, chunks: [] });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
router.post('/openai/decode', jsonParser, async function (req, res) {
|
||||||
|
try {
|
||||||
|
const queryModel = String(req.query.model || '');
|
||||||
|
|
||||||
|
if (queryModel.includes('llama')) {
|
||||||
|
const handler = createSentencepieceDecodingHandler(spp_llama);
|
||||||
|
return handler(req, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (queryModel.includes('mistral')) {
|
||||||
|
const handler = createSentencepieceDecodingHandler(spp_mistral);
|
||||||
|
return handler(req, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (queryModel.includes('yi')) {
|
||||||
|
const handler = createSentencepieceDecodingHandler(spp_yi);
|
||||||
|
return handler(req, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (queryModel.includes('claude')) {
|
||||||
|
const ids = req.body.ids || [];
|
||||||
|
const chunkText = await claude_tokenizer.decode(new Uint32Array(ids));
|
||||||
|
return res.send({ text: chunkText });
|
||||||
|
}
|
||||||
|
|
||||||
|
const model = getTokenizerModel(queryModel);
|
||||||
|
const handler = createTiktokenDecodingHandler(model);
|
||||||
|
return handler(req, res);
|
||||||
|
} catch (error) {
|
||||||
|
console.log(error);
|
||||||
|
return res.send({ text: '' });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
router.post('/openai/count', jsonParser, async function (req, res) {
|
||||||
|
try {
|
||||||
if (!req.body) return res.sendStatus(400);
|
if (!req.body) return res.sendStatus(400);
|
||||||
const options = {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
accept: 'application/json',
|
|
||||||
'content-type': 'application/json',
|
|
||||||
Authorization: `Bearer ${readSecret(SECRET_KEYS.AI21)}`,
|
|
||||||
},
|
|
||||||
body: JSON.stringify({ text: req.body[0].content }),
|
|
||||||
};
|
|
||||||
|
|
||||||
try {
|
let num_tokens = 0;
|
||||||
const response = await fetch('https://api.ai21.com/studio/v1/tokenize', options);
|
const queryModel = String(req.query.model || '');
|
||||||
const data = await response.json();
|
const model = getTokenizerModel(queryModel);
|
||||||
return res.send({ 'token_count': data?.tokens?.length || 0 });
|
|
||||||
} catch (err) {
|
if (model === 'claude') {
|
||||||
console.error(err);
|
num_tokens = countClaudeTokens(claude_tokenizer, req.body);
|
||||||
return res.send({ 'token_count': 0 });
|
return res.send({ 'token_count': num_tokens });
|
||||||
}
|
}
|
||||||
});
|
|
||||||
|
|
||||||
app.post('/api/tokenizers/llama/encode', jsonParser, createSentencepieceEncodingHandler(spp_llama));
|
if (model === 'llama') {
|
||||||
app.post('/api/tokenizers/nerdstash/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd));
|
num_tokens = await countSentencepieceArrayTokens(spp_llama, req.body);
|
||||||
app.post('/api/tokenizers/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
|
return res.send({ 'token_count': num_tokens });
|
||||||
app.post('/api/tokenizers/mistral/encode', jsonParser, createSentencepieceEncodingHandler(spp_mistral));
|
|
||||||
app.post('/api/tokenizers/yi/encode', jsonParser, createSentencepieceEncodingHandler(spp_yi));
|
|
||||||
app.post('/api/tokenizers/gpt2/encode', jsonParser, createTiktokenEncodingHandler('gpt2'));
|
|
||||||
app.post('/api/tokenizers/llama/decode', jsonParser, createSentencepieceDecodingHandler(spp_llama));
|
|
||||||
app.post('/api/tokenizers/nerdstash/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd));
|
|
||||||
app.post('/api/tokenizers/nerdstash_v2/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
|
|
||||||
app.post('/api/tokenizers/mistral/decode', jsonParser, createSentencepieceDecodingHandler(spp_mistral));
|
|
||||||
app.post('/api/tokenizers/yi/decode', jsonParser, createSentencepieceDecodingHandler(spp_yi));
|
|
||||||
app.post('/api/tokenizers/gpt2/decode', jsonParser, createTiktokenDecodingHandler('gpt2'));
|
|
||||||
|
|
||||||
app.post('/api/tokenizers/openai/encode', jsonParser, async function (req, res) {
|
|
||||||
try {
|
|
||||||
const queryModel = String(req.query.model || '');
|
|
||||||
|
|
||||||
if (queryModel.includes('llama')) {
|
|
||||||
const handler = createSentencepieceEncodingHandler(spp_llama);
|
|
||||||
return handler(req, res);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (queryModel.includes('mistral')) {
|
|
||||||
const handler = createSentencepieceEncodingHandler(spp_mistral);
|
|
||||||
return handler(req, res);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (queryModel.includes('yi')) {
|
|
||||||
const handler = createSentencepieceEncodingHandler(spp_yi);
|
|
||||||
return handler(req, res);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (queryModel.includes('claude')) {
|
|
||||||
const text = req.body.text || '';
|
|
||||||
const tokens = Object.values(claude_tokenizer.encode(text));
|
|
||||||
const chunks = await getWebTokenizersChunks(claude_tokenizer, tokens);
|
|
||||||
return res.send({ ids: tokens, count: tokens.length, chunks });
|
|
||||||
}
|
|
||||||
|
|
||||||
const model = getTokenizerModel(queryModel);
|
|
||||||
const handler = createTiktokenEncodingHandler(model);
|
|
||||||
return handler(req, res);
|
|
||||||
} catch (error) {
|
|
||||||
console.log(error);
|
|
||||||
return res.send({ ids: [], count: 0, chunks: [] });
|
|
||||||
}
|
}
|
||||||
});
|
|
||||||
|
|
||||||
app.post('/api/tokenizers/openai/decode', jsonParser, async function (req, res) {
|
if (model === 'mistral') {
|
||||||
try {
|
num_tokens = await countSentencepieceArrayTokens(spp_mistral, req.body);
|
||||||
const queryModel = String(req.query.model || '');
|
return res.send({ 'token_count': num_tokens });
|
||||||
|
|
||||||
if (queryModel.includes('llama')) {
|
|
||||||
const handler = createSentencepieceDecodingHandler(spp_llama);
|
|
||||||
return handler(req, res);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (queryModel.includes('mistral')) {
|
|
||||||
const handler = createSentencepieceDecodingHandler(spp_mistral);
|
|
||||||
return handler(req, res);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (queryModel.includes('yi')) {
|
|
||||||
const handler = createSentencepieceDecodingHandler(spp_yi);
|
|
||||||
return handler(req, res);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (queryModel.includes('claude')) {
|
|
||||||
const ids = req.body.ids || [];
|
|
||||||
const chunkText = await claude_tokenizer.decode(new Uint32Array(ids));
|
|
||||||
return res.send({ text: chunkText });
|
|
||||||
}
|
|
||||||
|
|
||||||
const model = getTokenizerModel(queryModel);
|
|
||||||
const handler = createTiktokenDecodingHandler(model);
|
|
||||||
return handler(req, res);
|
|
||||||
} catch (error) {
|
|
||||||
console.log(error);
|
|
||||||
return res.send({ text: '' });
|
|
||||||
}
|
}
|
||||||
});
|
|
||||||
|
|
||||||
app.post('/api/tokenizers/openai/count', jsonParser, async function (req, res) {
|
if (model === 'yi') {
|
||||||
try {
|
num_tokens = await countSentencepieceArrayTokens(spp_yi, req.body);
|
||||||
if (!req.body) return res.sendStatus(400);
|
return res.send({ 'token_count': num_tokens });
|
||||||
|
}
|
||||||
|
|
||||||
let num_tokens = 0;
|
const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
|
||||||
const queryModel = String(req.query.model || '');
|
const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
|
||||||
const model = getTokenizerModel(queryModel);
|
const tokensPadding = 3;
|
||||||
|
|
||||||
if (model === 'claude') {
|
const tokenizer = getTiktokenTokenizer(model);
|
||||||
num_tokens = countClaudeTokens(claude_tokenizer, req.body);
|
|
||||||
return res.send({ 'token_count': num_tokens });
|
|
||||||
}
|
|
||||||
|
|
||||||
if (model === 'llama') {
|
for (const msg of req.body) {
|
||||||
num_tokens = await countSentencepieceArrayTokens(spp_llama, req.body);
|
try {
|
||||||
return res.send({ 'token_count': num_tokens });
|
num_tokens += tokensPerMessage;
|
||||||
}
|
for (const [key, value] of Object.entries(msg)) {
|
||||||
|
num_tokens += tokenizer.encode(value).length;
|
||||||
if (model === 'mistral') {
|
if (key == 'name') {
|
||||||
num_tokens = await countSentencepieceArrayTokens(spp_mistral, req.body);
|
num_tokens += tokensPerName;
|
||||||
return res.send({ 'token_count': num_tokens });
|
|
||||||
}
|
|
||||||
|
|
||||||
if (model === 'yi') {
|
|
||||||
num_tokens = await countSentencepieceArrayTokens(spp_yi, req.body);
|
|
||||||
return res.send({ 'token_count': num_tokens });
|
|
||||||
}
|
|
||||||
|
|
||||||
const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
|
|
||||||
const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
|
|
||||||
const tokensPadding = 3;
|
|
||||||
|
|
||||||
const tokenizer = getTiktokenTokenizer(model);
|
|
||||||
|
|
||||||
for (const msg of req.body) {
|
|
||||||
try {
|
|
||||||
num_tokens += tokensPerMessage;
|
|
||||||
for (const [key, value] of Object.entries(msg)) {
|
|
||||||
num_tokens += tokenizer.encode(value).length;
|
|
||||||
if (key == 'name') {
|
|
||||||
num_tokens += tokensPerName;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} catch {
|
|
||||||
console.warn('Error tokenizing message:', msg);
|
|
||||||
}
|
}
|
||||||
|
} catch {
|
||||||
|
console.warn('Error tokenizing message:', msg);
|
||||||
}
|
}
|
||||||
num_tokens += tokensPadding;
|
|
||||||
|
|
||||||
// NB: Since 2023-10-14, the GPT-3.5 Turbo 0301 model shoves in 7-9 extra tokens to every message.
|
|
||||||
// More details: https://community.openai.com/t/gpt-3-5-turbo-0301-showing-different-behavior-suddenly/431326/14
|
|
||||||
if (queryModel.includes('gpt-3.5-turbo-0301')) {
|
|
||||||
num_tokens += 9;
|
|
||||||
}
|
|
||||||
|
|
||||||
// not needed for cached tokenizers
|
|
||||||
//tokenizer.free();
|
|
||||||
|
|
||||||
res.send({ 'token_count': num_tokens });
|
|
||||||
} catch (error) {
|
|
||||||
console.error('An error counting tokens, using fallback estimation method', error);
|
|
||||||
const jsonBody = JSON.stringify(req.body);
|
|
||||||
const num_tokens = Math.ceil(jsonBody.length / CHARS_PER_TOKEN);
|
|
||||||
res.send({ 'token_count': num_tokens });
|
|
||||||
}
|
}
|
||||||
});
|
num_tokens += tokensPadding;
|
||||||
}
|
|
||||||
|
// NB: Since 2023-10-14, the GPT-3.5 Turbo 0301 model shoves in 7-9 extra tokens to every message.
|
||||||
|
// More details: https://community.openai.com/t/gpt-3-5-turbo-0301-showing-different-behavior-suddenly/431326/14
|
||||||
|
if (queryModel.includes('gpt-3.5-turbo-0301')) {
|
||||||
|
num_tokens += 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
// not needed for cached tokenizers
|
||||||
|
//tokenizer.free();
|
||||||
|
|
||||||
|
res.send({ 'token_count': num_tokens });
|
||||||
|
} catch (error) {
|
||||||
|
console.error('An error counting tokens, using fallback estimation method', error);
|
||||||
|
const jsonBody = JSON.stringify(req.body);
|
||||||
|
const num_tokens = Math.ceil(jsonBody.length / CHARS_PER_TOKEN);
|
||||||
|
res.send({ 'token_count': num_tokens });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
TEXT_COMPLETION_MODELS,
|
TEXT_COMPLETION_MODELS,
|
||||||
@ -543,8 +540,7 @@ module.exports = {
|
|||||||
getTiktokenTokenizer,
|
getTiktokenTokenizer,
|
||||||
countClaudeTokens,
|
countClaudeTokens,
|
||||||
loadTokenizers,
|
loadTokenizers,
|
||||||
registerEndpoints,
|
|
||||||
getSentencepiceTokenizer,
|
getSentencepiceTokenizer,
|
||||||
sentencepieceTokenizers,
|
sentencepieceTokenizers,
|
||||||
|
router,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user