Add llama 3 tokenizer
@@ -142,6 +142,7 @@ const spp_nerd_v2 = new SentencePieceTokenizer('src/tokenizers/nerdstash_v2.model');
 const spp_mistral = new SentencePieceTokenizer('src/tokenizers/mistral.model');
 const spp_yi = new SentencePieceTokenizer('src/tokenizers/yi.model');
 const claude_tokenizer = new WebTokenizer('src/tokenizers/claude.json');
+const llama3_tokenizer = new WebTokenizer('src/tokenizers/llama3.json');
 
 const sentencepieceTokenizers = [
     'llama',
@@ -285,6 +286,10 @@ function getTokenizerModel(requestModel) {
         return 'claude';
     }
 
+    if (requestModel.includes('llama3') || requestModel.includes('llama-3')) {
+        return 'llama3';
+    }
+
     if (requestModel.includes('llama')) {
         return 'llama';
     }
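
A couple of illustrative calls (the model ids below are made up for the example): since any Llama 3 id also contains the bare 'llama' substring, the new llama3 check has to run before the generic llama check.

getTokenizerModel('meta-llama/llama-3-70b-instruct'); // -> 'llama3'
getTokenizerModel('meta-llama/llama-2-13b-chat');     // -> 'llama'
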
@@ -313,12 +318,12 @@ function getTiktokenTokenizer(model) {
 }
 
 /**
- * Counts the tokens for the given messages using the Claude tokenizer.
+ * Counts the tokens for the given messages using the WebTokenizer and Claude prompt conversion.
  * @param {Tokenizer} tokenizer Web tokenizer
  * @param {object[]} messages Array of messages
  * @returns {number} Number of tokens
  */
-function countClaudeTokens(tokenizer, messages) {
+function countWebTokenizerTokens(tokenizer, messages) {
     // Should be fine if we use the old conversion method instead of the messages API one i think?
     const convertedPrompt = convertClaudePrompt(messages, false, '', false, false, '', false);
 
@@ -449,6 +454,67 @@ function createTiktokenDecodingHandler(modelId) {
     };
 }
 
+/**
+ * Creates an API handler for encoding WebTokenizer tokens.
+ * @param {WebTokenizer} tokenizer WebTokenizer instance
+ * @returns {TokenizationHandler} Handler function
+ */
+function createWebTokenizerEncodingHandler(tokenizer) {
+    /**
+     * Request handler for encoding WebTokenizer tokens.
+     * @param {import('express').Request} request
+     * @param {import('express').Response} response
+     */
+    return async function (request, response) {
+        try {
+            if (!request.body) {
+                return response.sendStatus(400);
+            }
+
+            const text = request.body.text || '';
+            const instance = await tokenizer?.get();
+            if (!instance) throw new Error('Failed to load the Web tokenizer');
+            const tokens = Array.from(instance.encode(text));
+            const chunks = getWebTokenizersChunks(instance, tokens);
+            return response.send({ ids: tokens, count: tokens.length, chunks });
+        } catch (error) {
+            console.log(error);
+            return response.send({ ids: [], count: 0, chunks: [] });
+        }
+    };
+}
+
+/**
+ * Creates an API handler for decoding WebTokenizer tokens.
+ * @param {WebTokenizer} tokenizer WebTokenizer instance
+ * @returns {TokenizationHandler} Handler function
+ */
+function createWebTokenizerDecodingHandler(tokenizer) {
+    /**
+     * Request handler for decoding WebTokenizer tokens.
+     * @param {import('express').Request} request
+     * @param {import('express').Response} response
+     * @returns {Promise<any>}
+     */
+    return async function (request, response) {
+        try {
+            if (!request.body) {
+                return response.sendStatus(400);
+            }
+
+            const ids = request.body.ids || [];
+            const instance = await tokenizer?.get();
+            if (!instance) throw new Error('Failed to load the Web tokenizer');
+            const chunks = getWebTokenizersChunks(instance, ids);
+            const text = instance.decode(new Int32Array(ids));
+            return response.send({ text, chunks });
+        } catch (error) {
+            console.log(error);
+            return response.send({ text: '', chunks: [] });
+        }
+    };
+}
+
 const router = express.Router();
 
 router.post('/ai21/count', jsonParser, async function (req, res) {
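
A minimal usage sketch (the sample text is a placeholder): the WebTokenizer wrapper loads its model lazily, so callers await get() before encoding or decoding, which is what both handlers above do on every request.

const instance = await llama3_tokenizer.get();             // lazily loads src/tokenizers/llama3.json
const ids = Array.from(instance.encode('Hello, world!'));  // plain array of token ids
const text = instance.decode(new Int32Array(ids));         // decodes back to the original string
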
@@ -501,17 +567,26 @@ router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
 router.post('/mistral/encode', jsonParser, createSentencepieceEncodingHandler(spp_mistral));
 router.post('/yi/encode', jsonParser, createSentencepieceEncodingHandler(spp_yi));
 router.post('/gpt2/encode', jsonParser, createTiktokenEncodingHandler('gpt2'));
+router.post('/claude/encode', jsonParser, createWebTokenizerEncodingHandler(claude_tokenizer));
+router.post('/llama3/encode', jsonParser, createWebTokenizerEncodingHandler(llama3_tokenizer));
 router.post('/llama/decode', jsonParser, createSentencepieceDecodingHandler(spp_llama));
 router.post('/nerdstash/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd));
 router.post('/nerdstash_v2/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
 router.post('/mistral/decode', jsonParser, createSentencepieceDecodingHandler(spp_mistral));
 router.post('/yi/decode', jsonParser, createSentencepieceDecodingHandler(spp_yi));
 router.post('/gpt2/decode', jsonParser, createTiktokenDecodingHandler('gpt2'));
+router.post('/claude/decode', jsonParser, createWebTokenizerDecodingHandler(claude_tokenizer));
+router.post('/llama3/decode', jsonParser, createWebTokenizerDecodingHandler(llama3_tokenizer));
 
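
A hedged client-side sketch of calling one of the new routes (the /api/tokenizers prefix is an assumption about where this router is mounted; the response shape follows createWebTokenizerEncodingHandler above):

const res = await fetch('/api/tokenizers/llama3/encode', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ text: 'The quick brown fox' }),
});
const { ids, count, chunks } = await res.json();
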
 router.post('/openai/encode', jsonParser, async function (req, res) {
     try {
         const queryModel = String(req.query.model || '');
 
+        if (queryModel.includes('llama3') || queryModel.includes('llama-3')) {
+            const handler = createWebTokenizerEncodingHandler(llama3_tokenizer);
+            return handler(req, res);
+        }
+
         if (queryModel.includes('llama')) {
             const handler = createSentencepieceEncodingHandler(spp_llama);
             return handler(req, res);
@@ -528,12 +603,8 @@ router.post('/openai/encode', jsonParser, async function (req, res) {
         }
 
         if (queryModel.includes('claude')) {
-            const text = req.body.text || '';
-            const instance = await claude_tokenizer.get();
-            if (!instance) throw new Error('Failed to load the Claude tokenizer');
-            const tokens = Object.values(instance.encode(text));
-            const chunks = getWebTokenizersChunks(instance, tokens);
-            return res.send({ ids: tokens, count: tokens.length, chunks });
+            const handler = createWebTokenizerEncodingHandler(claude_tokenizer);
+            return handler(req, res);
         }
 
         const model = getTokenizerModel(queryModel);
@@ -549,6 +620,11 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
     try {
         const queryModel = String(req.query.model || '');
 
+        if (queryModel.includes('llama3') || queryModel.includes('llama-3')) {
+            const handler = createWebTokenizerDecodingHandler(llama3_tokenizer);
+            return handler(req, res);
+        }
+
         if (queryModel.includes('llama')) {
             const handler = createSentencepieceDecodingHandler(spp_llama);
             return handler(req, res);
@@ -565,11 +641,8 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
         }
 
         if (queryModel.includes('claude')) {
-            const ids = req.body.ids || [];
-            const instance = await claude_tokenizer.get();
-            if (!instance) throw new Error('Failed to load the Claude tokenizer');
-            const chunkText = instance.decode(new Int32Array(ids));
-            return res.send({ text: chunkText });
+            const handler = createWebTokenizerDecodingHandler(claude_tokenizer);
+            return handler(req, res);
         }
 
         const model = getTokenizerModel(queryModel);
@@ -592,7 +665,14 @@ router.post('/openai/count', jsonParser, async function (req, res) {
         if (model === 'claude') {
             const instance = await claude_tokenizer.get();
             if (!instance) throw new Error('Failed to load the Claude tokenizer');
-            num_tokens = countClaudeTokens(instance, req.body);
+            num_tokens = countWebTokenizerTokens(instance, req.body);
             return res.send({ 'token_count': num_tokens });
         }
 
+        if (model === 'llama3' || model === 'llama-3') {
+            const instance = await llama3_tokenizer.get();
+            if (!instance) throw new Error('Failed to load the Llama3 tokenizer');
+            num_tokens = countWebTokenizerTokens(instance, req.body);
+            return res.send({ 'token_count': num_tokens });
+        }
+
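
A short sketch of the counting path above (the message array is illustrative): countWebTokenizerTokens converts the messages with convertClaudePrompt and then counts the tokens of the converted prompt.

const instance = await llama3_tokenizer.get();
if (!instance) throw new Error('Failed to load the Llama3 tokenizer');
const num_tokens = countWebTokenizerTokens(instance, [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Hello!' },
]);
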
@@ -755,7 +835,7 @@ module.exports = {
     TEXT_COMPLETION_MODELS,
     getTokenizerModel,
     getTiktokenTokenizer,
-    countClaudeTokens,
+    countWebTokenizerTokens,
     getSentencepiceTokenizer,
     sentencepieceTokenizers,
     router,