diff --git a/public/index.html b/public/index.html index af512ce43..ecad679cc 100644 --- a/public/index.html +++ b/public/index.html @@ -2306,6 +2306,7 @@ + diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js index cbff17f6d..c6ab38cad 100644 --- a/public/scripts/tokenizers.js +++ b/public/scripts/tokenizers.js @@ -18,6 +18,7 @@ export const tokenizers = { NERD2: 5, API: 6, MISTRAL: 7, + YI: 8, BEST_MATCH: 99, }; @@ -148,6 +149,8 @@ function callTokenizer(type, str, padding) { return countTokensRemote('/api/tokenize/nerdstash_v2', str, padding); case tokenizers.MISTRAL: return countTokensRemote('/api/tokenize/mistral', str, padding); + case tokenizers.YI: + return countTokensRemote('/api/tokenize/yi', str, padding); case tokenizers.API: return countTokensRemote('/tokenize_via_api', str, padding); default: @@ -229,6 +232,7 @@ export function getTokenizerModel() { const claudeTokenizer = 'claude'; const llamaTokenizer = 'llama'; const mistralTokenizer = 'mistral'; + const yiTokenizer = 'yi'; // Assuming no one would use it for different models.. right? if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) { @@ -264,6 +268,9 @@ export function getTokenizerModel() { else if (model?.architecture?.tokenizer === 'Mistral') { return mistralTokenizer; } + else if (model?.architecture?.tokenizer === 'Yi') { + return yiTokenizer; + } else if (oai_settings.openrouter_model.includes('gpt-4')) { return gpt4Tokenizer; } @@ -485,6 +492,8 @@ export function getTextTokens(tokenizerType, str) { return getTextTokensRemote('/api/tokenize/nerdstash_v2', str); case tokenizers.MISTRAL: return getTextTokensRemote('/api/tokenize/mistral', str); + case tokenizers.YI: + return getTextTokensRemote('/api/tokenize/yi', str); case tokenizers.OPENAI: const model = getTokenizerModel(); return getTextTokensRemote('/api/tokenize/openai-encode', str, model); @@ -513,6 +522,8 @@ export function decodeTextTokens(tokenizerType, ids) { return decodeTextTokensRemote('/api/decode/nerdstash_v2', ids); case tokenizers.MISTRAL: return decodeTextTokensRemote('/api/decode/mistral', ids); + case tokenizers.YI: + return decodeTextTokensRemote('/api/decode/yi', ids); default: console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType); return ''; diff --git a/src/sentencepiece/yi.model b/src/sentencepiece/yi.model new file mode 100644 index 000000000..0c3136e08 Binary files /dev/null and b/src/sentencepiece/yi.model differ diff --git a/src/tokenizers.js b/src/tokenizers.js index 72cb4101c..c1bd90f75 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -76,6 +76,7 @@ const spp_llama = new SentencePieceTokenizer('src/sentencepiece/llama.model'); const spp_nerd = new SentencePieceTokenizer('src/sentencepiece/nerdstash.model'); const spp_nerd_v2 = new SentencePieceTokenizer('src/sentencepiece/nerdstash_v2.model'); const spp_mistral = new SentencePieceTokenizer('src/sentencepiece/mistral.model'); +const spp_yi = new SentencePieceTokenizer('src/sentencepiece/yi.model'); let claude_tokenizer; const sentencepieceTokenizers = [ @@ -181,18 +182,6 @@ async function getWebTokenizersChunks(tokenizer, ids) { * @returns {string} Tokenizer model to use */ function getTokenizerModel(requestModel) { - if (requestModel.includes('claude')) { - return 'claude'; - } - - if (requestModel.includes('llama')) { - return 'llama'; - } - - if (requestModel.includes('mistral')) { - return 'mistral'; - } - if (requestModel.includes('gpt-4-32k')) { return 'gpt-4-32k'; } @@ -213,6 +202,22 @@ function getTokenizerModel(requestModel) { return requestModel; } + if (requestModel.includes('claude')) { + return 'claude'; + } + + if (requestModel.includes('llama')) { + return 'llama'; + } + + if (requestModel.includes('mistral')) { + return 'mistral'; + } + + if (requestModel.includes('yi')) { + return 'yi'; + } + // default return 'gpt-3.5-turbo'; } @@ -386,11 +391,13 @@ function registerEndpoints(app, jsonParser) { app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(spp_nerd)); app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2)); app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(spp_mistral)); + app.post("/api/tokenize/yi", jsonParser, createSentencepieceEncodingHandler(spp_yi)); app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2')); app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(spp_llama)); app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(spp_nerd)); app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2)); app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(spp_mistral)); + app.post("/api/decode/yi", jsonParser, createSentencepieceDecodingHandler(spp_yi)); app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2')); app.post("/api/tokenize/openai-encode", jsonParser, async function (req, res) { @@ -407,6 +414,11 @@ function registerEndpoints(app, jsonParser) { return handler(req, res); } + if (queryModel.includes('yi')) { + const handler = createSentencepieceEncodingHandler(spp_yi); + return handler(req, res); + } + if (queryModel.includes('claude')) { const text = req.body.text || ''; const tokens = Object.values(claude_tokenizer.encode(text)); @@ -431,21 +443,26 @@ function registerEndpoints(app, jsonParser) { const queryModel = String(req.query.model || ''); const model = getTokenizerModel(queryModel); - if (model == 'claude') { + if (model === 'claude') { num_tokens = countClaudeTokens(claude_tokenizer, req.body); return res.send({ "token_count": num_tokens }); } - if (model == 'llama') { + if (model === 'llama') { num_tokens = await countSentencepieceArrayTokens(spp_llama, req.body); return res.send({ "token_count": num_tokens }); } - if (model == 'mistral') { + if (model === 'mistral') { num_tokens = await countSentencepieceArrayTokens(spp_mistral, req.body); return res.send({ "token_count": num_tokens }); } + if (model === 'yi') { + num_tokens = await countSentencepieceArrayTokens(spp_yi, req.body); + return res.send({ "token_count": num_tokens }); + } + const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1; const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3; const tokensPadding = 3;