diff --git a/public/index.html b/public/index.html
index 8fa60544f..28d9d71c1 100644
--- a/public/index.html
+++ b/public/index.html
@@ -2384,9 +2384,10 @@
-
+
+
@@ -5044,4 +5045,4 @@
-
\ No newline at end of file
+
diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js
index 4d32b2050..e78786646 100644
--- a/public/scripts/tokenizers.js
+++ b/public/scripts/tokenizers.js
@@ -16,6 +16,7 @@ export const tokenizers = {
     NERD: 4,
     NERD2: 5,
     API: 6,
+    MISTRAL: 7,
     BEST_MATCH: 99,
 };
 
@@ -105,6 +106,8 @@ function callTokenizer(type, str, padding) {
             return countTokensRemote('/api/tokenize/nerdstash', str, padding);
         case tokenizers.NERD2:
             return countTokensRemote('/api/tokenize/nerdstash_v2', str, padding);
+        case tokenizers.MISTRAL:
+            return countTokensRemote('/api/tokenize/mistral', str, padding);
         case tokenizers.API:
             return countTokensRemote('/tokenize_via_api', str, padding);
         default:
@@ -185,6 +188,7 @@ export function getTokenizerModel() {
     const gpt2Tokenizer = 'gpt2';
     const claudeTokenizer = 'claude';
     const llamaTokenizer = 'llama';
+    const mistralTokenizer = 'mistral';
 
     // Assuming no one would use it for different models.. right?
     if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) {
@@ -217,6 +221,9 @@ export function getTokenizerModel() {
         if (model?.architecture?.tokenizer === 'Llama2') {
             return llamaTokenizer;
         }
+        else if (model?.architecture?.tokenizer === 'Mistral') {
+            return mistralTokenizer;
+        }
         else if (oai_settings.openrouter_model.includes('gpt-4')) {
             return gpt4Tokenizer;
         }
@@ -420,6 +427,8 @@ export function getTextTokens(tokenizerType, str) {
             return getTextTokensRemote('/api/tokenize/nerdstash', str);
         case tokenizers.NERD2:
             return getTextTokensRemote('/api/tokenize/nerdstash_v2', str);
+        case tokenizers.MISTRAL:
+            return getTextTokensRemote('/api/tokenize/mistral', str);
         case tokenizers.OPENAI:
             const model = getTokenizerModel();
             return getTextTokensRemote('/api/tokenize/openai-encode', str, model);
@@ -444,6 +453,8 @@ export function decodeTextTokens(tokenizerType, ids) {
             return decodeTextTokensRemote('/api/decode/nerdstash', ids);
         case tokenizers.NERD2:
             return decodeTextTokensRemote('/api/decode/nerdstash_v2', ids);
+        case tokenizers.MISTRAL:
+            return decodeTextTokensRemote('/api/decode/mistral', ids);
         default:
             console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType);
             return '';
diff --git a/src/sentencepiece/tokenizer.model b/src/sentencepiece/llama.model
similarity index 100%
rename from src/sentencepiece/tokenizer.model
rename to src/sentencepiece/llama.model
diff --git a/src/sentencepiece/mistral.model b/src/sentencepiece/mistral.model
new file mode 100644
index 000000000..85c0803f3
Binary files /dev/null and b/src/sentencepiece/mistral.model differ
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 1ba3e20c3..7cc440e37 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -46,6 +46,7 @@ const CHARS_PER_TOKEN = 3.35;
 let spp_llama;
 let spp_nerd;
 let spp_nerd_v2;
+let spp_mistral;
 let claude_tokenizer;
 
 async function loadSentencepieceTokenizer(modelPath) {
@@ -91,6 +92,10 @@ function getTokenizerModel(requestModel) {
         return 'llama';
     }
 
+    if (requestModel.includes('mistral')) {
+        return 'mistral';
+    }
+
     if (requestModel.includes('gpt-4-32k')) {
         return 'gpt-4-32k';
     }
@@ -247,10 +252,11 @@ function createTiktokenDecodingHandler(modelId) {
  * @returns {Promise} Promise that resolves when the tokenizers are loaded
  */
 async function loadTokenizers() {
-    [spp_llama, spp_nerd, spp_nerd_v2, claude_tokenizer] = await Promise.all([
-        loadSentencepieceTokenizer('src/sentencepiece/tokenizer.model'),
+    [spp_llama, spp_nerd, spp_nerd_v2, spp_mistral, claude_tokenizer] = await Promise.all([
+        loadSentencepieceTokenizer('src/sentencepiece/llama.model'),
         loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
         loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
+        loadSentencepieceTokenizer('src/sentencepiece/mistral.model'),
         loadClaudeTokenizer('src/claude.json'),
     ]);
 }
@@ -286,10 +292,12 @@ function registerEndpoints(app, jsonParser) {
     app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(() => spp_llama));
     app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd));
     app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd_v2));
+    app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(() => spp_mistral));
     app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
     app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(() => spp_llama));
     app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd));
     app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2));
+    app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(() => spp_mistral));
     app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
 
     app.post("/api/tokenize/openai-encode", jsonParser, async function (req, res) {
@@ -301,6 +309,11 @@ function registerEndpoints(app, jsonParser) {
             return handler(req, res);
         }
 
+        if (queryModel.includes('mistral')) {
+            const handler = createSentencepieceEncodingHandler(() => spp_mistral);
+            return handler(req, res);
+        }
+
         if (queryModel.includes('claude')) {
             const text = req.body.text || '';
             const tokens = Object.values(claude_tokenizer.encode(text));
@@ -332,11 +345,17 @@ function registerEndpoints(app, jsonParser) {
         if (model == 'llama') {
             const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
             const llamaResult = await countSentencepieceTokens(spp_llama, jsonBody);
-            // console.log('jsonBody', jsonBody, 'llamaResult', llamaResult);
             num_tokens = llamaResult.count;
             return res.send({ "token_count": num_tokens });
         }
 
+        if (model == 'mistral') {
+            const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
+            const mistralResult = await countSentencepieceTokens(spp_mistral, jsonBody);
+            num_tokens = mistralResult.count;
+            return res.send({ "token_count": num_tokens });
+        }
+
         const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
         const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
         const tokensPadding = 3;
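As a quick way to exercise the new /api/tokenize/mistral and /api/decode/mistral routes once this lands, something like the snippet below could be run against a local instance. This is a minimal sketch only: the base URL, the { text } / { ids } request bodies, and the { ids, count } / { text } response shapes are assumptions about the SentencePiece handlers rather than anything shown in this diff, and any CSRF or auth headers the server requires are omitted.

// Hypothetical round-trip check for the new Mistral tokenizer endpoints (Node 18+ for global fetch).
// Assumed contract: encode takes { text } and returns { ids, count }; decode takes { ids } and returns { text }.
const BASE_URL = 'http://127.0.0.1:8000'; // assumed local server address

async function postJson(path, body) {
    // POST a JSON body and parse the JSON response.
    const response = await fetch(`${BASE_URL}${path}`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify(body),
    });
    return response.json();
}

async function roundTripMistral(text) {
    // Encode with the Mistral SentencePiece model and report the token count.
    const { ids, count } = await postJson('/api/tokenize/mistral', { text });
    console.log(`"${text}" -> ${count} tokens`);

    // Decode the ids back to text; the output should closely match the input.
    const { text: decoded } = await postJson('/api/decode/mistral', { ids });
    console.log('decoded:', decoded);
}

roundTripMistral('The quick brown fox jumps over the lazy dog.');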