Add llama 3 tokenizer

This commit is contained in:
Cohee
2024-05-03 23:59:39 +03:00
parent 7bc87b6e28
commit 7bfd666321
8 changed files with 143 additions and 21 deletions

View File

@ -142,6 +142,7 @@ const spp_nerd_v2 = new SentencePieceTokenizer('src/tokenizers/nerdstash_v2.mode
const spp_mistral = new SentencePieceTokenizer('src/tokenizers/mistral.model');
const spp_yi = new SentencePieceTokenizer('src/tokenizers/yi.model');
const claude_tokenizer = new WebTokenizer('src/tokenizers/claude.json');
const llama3_tokenizer = new WebTokenizer('src/tokenizers/llama3.json');
const sentencepieceTokenizers = [
'llama',
@ -285,6 +286,10 @@ function getTokenizerModel(requestModel) {
return 'claude';
}
if (requestModel.includes('llama3') || requestModel.includes('llama-3')) {
return 'llama3';
}
if (requestModel.includes('llama')) {
return 'llama';
}
@ -313,12 +318,12 @@ function getTiktokenTokenizer(model) {
}
/**
* Counts the tokens for the given messages using the Claude tokenizer.
* Counts the tokens for the given messages using the WebTokenizer and Claude prompt conversion.
* @param {Tokenizer} tokenizer Web tokenizer
* @param {object[]} messages Array of messages
* @returns {number} Number of tokens
*/
function countClaudeTokens(tokenizer, messages) {
function countWebTokenizerTokens(tokenizer, messages) {
// Should be fine if we use the old conversion method instead of the messages API one i think?
const convertedPrompt = convertClaudePrompt(messages, false, '', false, false, '', false);
@ -449,6 +454,67 @@ function createTiktokenDecodingHandler(modelId) {
};
}
/**
* Creates an API handler for encoding WebTokenizer tokens.
* @param {WebTokenizer} tokenizer WebTokenizer instance
* @returns {TokenizationHandler} Handler function
*/
function createWebTokenizerEncodingHandler(tokenizer) {
/**
* Request handler for encoding WebTokenizer tokens.
* @param {import('express').Request} request
* @param {import('express').Response} response
*/
return async function (request, response) {
try {
if (!request.body) {
return response.sendStatus(400);
}
const text = request.body.text || '';
const instance = await tokenizer?.get();
if (!instance) throw new Error('Failed to load the Web tokenizer');
const tokens = Array.from(instance.encode(text));
const chunks = getWebTokenizersChunks(instance, tokens);
return response.send({ ids: tokens, count: tokens.length, chunks });
} catch (error) {
console.log(error);
return response.send({ ids: [], count: 0, chunks: [] });
}
};
}
/**
* Creates an API handler for decoding WebTokenizer tokens.
* @param {WebTokenizer} tokenizer WebTokenizer instance
* @returns {TokenizationHandler} Handler function
*/
function createWebTokenizerDecodingHandler(tokenizer) {
/**
* Request handler for decoding WebTokenizer tokens.
* @param {import('express').Request} request
* @param {import('express').Response} response
* @returns {Promise<any>}
*/
return async function (request, response) {
try {
if (!request.body) {
return response.sendStatus(400);
}
const ids = request.body.ids || [];
const instance = await tokenizer?.get();
if (!instance) throw new Error('Failed to load the Web tokenizer');
const chunks = getWebTokenizersChunks(instance, ids);
const text = instance.decode(new Int32Array(ids));
return response.send({ text, chunks });
} catch (error) {
console.log(error);
return response.send({ text: '', chunks: [] });
}
};
}
const router = express.Router();
router.post('/ai21/count', jsonParser, async function (req, res) {
@ -501,17 +567,26 @@ router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandl
router.post('/mistral/encode', jsonParser, createSentencepieceEncodingHandler(spp_mistral));
router.post('/yi/encode', jsonParser, createSentencepieceEncodingHandler(spp_yi));
router.post('/gpt2/encode', jsonParser, createTiktokenEncodingHandler('gpt2'));
router.post('/claude/encode', jsonParser, createWebTokenizerEncodingHandler(claude_tokenizer));
router.post('/llama3/encode', jsonParser, createWebTokenizerEncodingHandler(llama3_tokenizer));
router.post('/llama/decode', jsonParser, createSentencepieceDecodingHandler(spp_llama));
router.post('/nerdstash/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd));
router.post('/nerdstash_v2/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
router.post('/mistral/decode', jsonParser, createSentencepieceDecodingHandler(spp_mistral));
router.post('/yi/decode', jsonParser, createSentencepieceDecodingHandler(spp_yi));
router.post('/gpt2/decode', jsonParser, createTiktokenDecodingHandler('gpt2'));
router.post('/claude/decode', jsonParser, createWebTokenizerDecodingHandler(claude_tokenizer));
router.post('/llama3/decode', jsonParser, createWebTokenizerDecodingHandler(llama3_tokenizer));
router.post('/openai/encode', jsonParser, async function (req, res) {
try {
const queryModel = String(req.query.model || '');
if (queryModel.includes('llama3') || queryModel.includes('llama-3')) {
const handler = createWebTokenizerEncodingHandler(llama3_tokenizer);
return handler(req, res);
}
if (queryModel.includes('llama')) {
const handler = createSentencepieceEncodingHandler(spp_llama);
return handler(req, res);
@ -528,12 +603,8 @@ router.post('/openai/encode', jsonParser, async function (req, res) {
}
if (queryModel.includes('claude')) {
const text = req.body.text || '';
const instance = await claude_tokenizer.get();
if (!instance) throw new Error('Failed to load the Claude tokenizer');
const tokens = Object.values(instance.encode(text));
const chunks = getWebTokenizersChunks(instance, tokens);
return res.send({ ids: tokens, count: tokens.length, chunks });
const handler = createWebTokenizerEncodingHandler(claude_tokenizer);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
@ -549,6 +620,11 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
try {
const queryModel = String(req.query.model || '');
if (queryModel.includes('llama3') || queryModel.includes('llama-3')) {
const handler = createWebTokenizerDecodingHandler(llama3_tokenizer);
return handler(req, res);
}
if (queryModel.includes('llama')) {
const handler = createSentencepieceDecodingHandler(spp_llama);
return handler(req, res);
@ -565,11 +641,8 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
}
if (queryModel.includes('claude')) {
const ids = req.body.ids || [];
const instance = await claude_tokenizer.get();
if (!instance) throw new Error('Failed to load the Claude tokenizer');
const chunkText = instance.decode(new Int32Array(ids));
return res.send({ text: chunkText });
const handler = createWebTokenizerDecodingHandler(claude_tokenizer);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
@ -592,7 +665,14 @@ router.post('/openai/count', jsonParser, async function (req, res) {
if (model === 'claude') {
const instance = await claude_tokenizer.get();
if (!instance) throw new Error('Failed to load the Claude tokenizer');
num_tokens = countClaudeTokens(instance, req.body);
num_tokens = countWebTokenizerTokens(instance, req.body);
return res.send({ 'token_count': num_tokens });
}
if (model === 'llama3' || model === 'llama-3') {
const instance = await llama3_tokenizer.get();
if (!instance) throw new Error('Failed to load the Llama3 tokenizer');
num_tokens = countWebTokenizerTokens(instance, req.body);
return res.send({ 'token_count': num_tokens });
}
@ -755,7 +835,7 @@ module.exports = {
TEXT_COMPLETION_MODELS,
getTokenizerModel,
getTiktokenTokenizer,
countClaudeTokens,
countWebTokenizerTokens,
getSentencepiceTokenizer,
sentencepieceTokenizers,
router,