Add tokenizer for Gemma/Gemini

This commit is contained in:
Cohee
2024-08-16 00:00:43 +03:00
parent ce8b0aae96
commit e707def7dd
5 changed files with 52 additions and 3 deletions

View File

@ -143,6 +143,7 @@ const spp_nerd = new SentencePieceTokenizer('src/tokenizers/nerdstash.model');
const spp_nerd_v2 = new SentencePieceTokenizer('src/tokenizers/nerdstash_v2.model');
const spp_mistral = new SentencePieceTokenizer('src/tokenizers/mistral.model');
const spp_yi = new SentencePieceTokenizer('src/tokenizers/yi.model');
const spp_gemma = new SentencePieceTokenizer('src/tokenizers/gemma.model');
const claude_tokenizer = new WebTokenizer('src/tokenizers/claude.json');
const llama3_tokenizer = new WebTokenizer('src/tokenizers/llama3.json');
@ -152,6 +153,7 @@ const sentencepieceTokenizers = [
'nerdstash_v2',
'mistral',
'yi',
'gemma',
];
/**
@ -180,6 +182,10 @@ function getSentencepiceTokenizer(model) {
return spp_yi;
}
if (model.includes('gemma')) {
return spp_gemma;
}
return null;
}
@ -312,8 +318,8 @@ function getTokenizerModel(requestModel) {
return 'yi';
}
if (requestModel.includes('gemini')) {
return 'gpt-4o';
if (requestModel.includes('gemma') || requestModel.includes('gemini')) {
return 'gemma';
}
// default
@ -583,6 +589,7 @@ router.post('/nerdstash/encode', jsonParser, createSentencepieceEncodingHandler(
router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
router.post('/mistral/encode', jsonParser, createSentencepieceEncodingHandler(spp_mistral));
router.post('/yi/encode', jsonParser, createSentencepieceEncodingHandler(spp_yi));
router.post('/gemma/encode', jsonParser, createSentencepieceEncodingHandler(spp_gemma));
router.post('/gpt2/encode', jsonParser, createTiktokenEncodingHandler('gpt2'));
router.post('/claude/encode', jsonParser, createWebTokenizerEncodingHandler(claude_tokenizer));
router.post('/llama3/encode', jsonParser, createWebTokenizerEncodingHandler(llama3_tokenizer));
@ -591,6 +598,7 @@ router.post('/nerdstash/decode', jsonParser, createSentencepieceDecodingHandler(
router.post('/nerdstash_v2/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
router.post('/mistral/decode', jsonParser, createSentencepieceDecodingHandler(spp_mistral));
router.post('/yi/decode', jsonParser, createSentencepieceDecodingHandler(spp_yi));
router.post('/gemma/decode', jsonParser, createSentencepieceDecodingHandler(spp_gemma));
router.post('/gpt2/decode', jsonParser, createTiktokenDecodingHandler('gpt2'));
router.post('/claude/decode', jsonParser, createWebTokenizerDecodingHandler(claude_tokenizer));
router.post('/llama3/decode', jsonParser, createWebTokenizerDecodingHandler(llama3_tokenizer));
@ -624,6 +632,11 @@ router.post('/openai/encode', jsonParser, async function (req, res) {
return handler(req, res);
}
if (queryModel.includes('gemma') || queryModel.includes('gemini')) {
const handler = createSentencepieceEncodingHandler(spp_gemma);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
const handler = createTiktokenEncodingHandler(model);
return handler(req, res);
@ -662,6 +675,11 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
return handler(req, res);
}
if (queryModel.includes('gemma') || queryModel.includes('gemini')) {
const handler = createSentencepieceDecodingHandler(spp_gemma);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
const handler = createTiktokenDecodingHandler(model);
return handler(req, res);
@ -708,6 +726,11 @@ router.post('/openai/count', jsonParser, async function (req, res) {
return res.send({ 'token_count': num_tokens });
}
if (model === 'gemma' || model === 'gemini') {
num_tokens = await countSentencepieceArrayTokens(spp_gemma, req.body);
return res.send({ 'token_count': num_tokens });
}
const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
const tokensPadding = 3;