New AI21 Jamba + tokenizer

This commit is contained in:
Cohee
2024-08-26 12:07:36 +03:00
parent ff834efde3
commit 5fc16a2474
10 changed files with 188 additions and 266 deletions

View File

@ -144,6 +144,7 @@ const spp_nerd_v2 = new SentencePieceTokenizer('src/tokenizers/nerdstash_v2.mode
const spp_mistral = new SentencePieceTokenizer('src/tokenizers/mistral.model');
const spp_yi = new SentencePieceTokenizer('src/tokenizers/yi.model');
const spp_gemma = new SentencePieceTokenizer('src/tokenizers/gemma.model');
const spp_jamba = new SentencePieceTokenizer('src/tokenizers/jamba.model');
const claude_tokenizer = new WebTokenizer('src/tokenizers/claude.json');
const llama3_tokenizer = new WebTokenizer('src/tokenizers/llama3.json');
@ -154,6 +155,7 @@ const sentencepieceTokenizers = [
'mistral',
'yi',
'gemma',
'jamba',
];
/**
@ -186,6 +188,10 @@ function getSentencepiceTokenizer(model) {
return spp_gemma;
}
if (model.includes('jamba')) {
return spp_jamba;
}
return null;
}
@ -322,6 +328,10 @@ function getTokenizerModel(requestModel) {
return 'gemma';
}
if (requestModel.includes('jamba')) {
return 'jamba';
}
// default
return 'gpt-3.5-turbo';
}
@ -537,59 +547,13 @@ function createWebTokenizerDecodingHandler(tokenizer) {
const router = express.Router();
router.post('/ai21/count', jsonParser, async function (req, res) {
if (!req.body) return res.sendStatus(400);
const key = readSecret(req.user.directories, SECRET_KEYS.AI21);
const options = {
method: 'POST',
headers: {
accept: 'application/json',
'content-type': 'application/json',
Authorization: `Bearer ${key}`,
},
body: JSON.stringify({ text: req.body[0].content }),
};
try {
const response = await fetch('https://api.ai21.com/studio/v1/tokenize', options);
const data = await response.json();
return res.send({ 'token_count': data?.tokens?.length || 0 });
} catch (err) {
console.error(err);
return res.send({ 'token_count': 0 });
}
});
router.post('/google/count', jsonParser, async function (req, res) {
if (!req.body) return res.sendStatus(400);
const options = {
method: 'POST',
headers: {
accept: 'application/json',
'content-type': 'application/json',
},
body: JSON.stringify({ contents: convertGooglePrompt(req.body, String(req.query.model)).contents }),
};
try {
const reverseProxy = req.query.reverse_proxy?.toString() || '';
const proxyPassword = req.query.proxy_password?.toString() || '';
const apiKey = reverseProxy ? proxyPassword : readSecret(req.user.directories, SECRET_KEYS.MAKERSUITE);
const apiUrl = new URL(reverseProxy || API_MAKERSUITE);
const response = await fetch(`${apiUrl.origin}/v1beta/models/${req.query.model}:countTokens?key=${apiKey}`, options);
const data = await response.json();
return res.send({ 'token_count': data?.totalTokens || 0 });
} catch (err) {
console.error(err);
return res.send({ 'token_count': 0 });
}
});
router.post('/llama/encode', jsonParser, createSentencepieceEncodingHandler(spp_llama));
router.post('/nerdstash/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd));
router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
router.post('/mistral/encode', jsonParser, createSentencepieceEncodingHandler(spp_mistral));
router.post('/yi/encode', jsonParser, createSentencepieceEncodingHandler(spp_yi));
router.post('/gemma/encode', jsonParser, createSentencepieceEncodingHandler(spp_gemma));
router.post('/jamba/encode', jsonParser, createSentencepieceEncodingHandler(spp_jamba));
router.post('/gpt2/encode', jsonParser, createTiktokenEncodingHandler('gpt2'));
router.post('/claude/encode', jsonParser, createWebTokenizerEncodingHandler(claude_tokenizer));
router.post('/llama3/encode', jsonParser, createWebTokenizerEncodingHandler(llama3_tokenizer));
@ -599,6 +563,7 @@ router.post('/nerdstash_v2/decode', jsonParser, createSentencepieceDecodingHandl
router.post('/mistral/decode', jsonParser, createSentencepieceDecodingHandler(spp_mistral));
router.post('/yi/decode', jsonParser, createSentencepieceDecodingHandler(spp_yi));
router.post('/gemma/decode', jsonParser, createSentencepieceDecodingHandler(spp_gemma));
router.post('/jamba/decode', jsonParser, createSentencepieceDecodingHandler(spp_jamba));
router.post('/gpt2/decode', jsonParser, createTiktokenDecodingHandler('gpt2'));
router.post('/claude/decode', jsonParser, createWebTokenizerDecodingHandler(claude_tokenizer));
router.post('/llama3/decode', jsonParser, createWebTokenizerDecodingHandler(llama3_tokenizer));
@ -637,6 +602,11 @@ router.post('/openai/encode', jsonParser, async function (req, res) {
return handler(req, res);
}
if (queryModel.includes('jamba')) {
const handler = createSentencepieceEncodingHandler(spp_jamba);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
const handler = createTiktokenEncodingHandler(model);
return handler(req, res);
@ -680,6 +650,11 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
return handler(req, res);
}
if (queryModel.includes('jamba')) {
const handler = createSentencepieceDecodingHandler(spp_jamba);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
const handler = createTiktokenDecodingHandler(model);
return handler(req, res);
@ -731,6 +706,11 @@ router.post('/openai/count', jsonParser, async function (req, res) {
return res.send({ 'token_count': num_tokens });
}
if (model === 'jamba') {
num_tokens = await countSentencepieceArrayTokens(spp_jamba, req.body);
return res.send({ 'token_count': num_tokens });
}
const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
const tokensPadding = 3;