diff --git a/public/index.html b/public/index.html
index af512ce43..ecad679cc 100644
--- a/public/index.html
+++ b/public/index.html
@@ -2306,6 +2306,7 @@
+
diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js
index cbff17f6d..c6ab38cad 100644
--- a/public/scripts/tokenizers.js
+++ b/public/scripts/tokenizers.js
@@ -18,6 +18,7 @@ export const tokenizers = {
NERD2: 5,
API: 6,
MISTRAL: 7,
+ YI: 8,
BEST_MATCH: 99,
};
@@ -148,6 +149,8 @@ function callTokenizer(type, str, padding) {
return countTokensRemote('/api/tokenize/nerdstash_v2', str, padding);
case tokenizers.MISTRAL:
return countTokensRemote('/api/tokenize/mistral', str, padding);
+ case tokenizers.YI:
+ return countTokensRemote('/api/tokenize/yi', str, padding);
case tokenizers.API:
return countTokensRemote('/tokenize_via_api', str, padding);
default:
@@ -229,6 +232,7 @@ export function getTokenizerModel() {
const claudeTokenizer = 'claude';
const llamaTokenizer = 'llama';
const mistralTokenizer = 'mistral';
+ const yiTokenizer = 'yi';
// Assuming no one would use it for different models.. right?
if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) {
@@ -264,6 +268,9 @@ export function getTokenizerModel() {
else if (model?.architecture?.tokenizer === 'Mistral') {
return mistralTokenizer;
}
+ else if (model?.architecture?.tokenizer === 'Yi') {
+ return yiTokenizer;
+ }
else if (oai_settings.openrouter_model.includes('gpt-4')) {
return gpt4Tokenizer;
}
@@ -485,6 +492,8 @@ export function getTextTokens(tokenizerType, str) {
return getTextTokensRemote('/api/tokenize/nerdstash_v2', str);
case tokenizers.MISTRAL:
return getTextTokensRemote('/api/tokenize/mistral', str);
+ case tokenizers.YI:
+ return getTextTokensRemote('/api/tokenize/yi', str);
case tokenizers.OPENAI:
const model = getTokenizerModel();
return getTextTokensRemote('/api/tokenize/openai-encode', str, model);
@@ -513,6 +522,8 @@ export function decodeTextTokens(tokenizerType, ids) {
return decodeTextTokensRemote('/api/decode/nerdstash_v2', ids);
case tokenizers.MISTRAL:
return decodeTextTokensRemote('/api/decode/mistral', ids);
+ case tokenizers.YI:
+ return decodeTextTokensRemote('/api/decode/yi', ids);
default:
console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType);
return '';
diff --git a/src/sentencepiece/yi.model b/src/sentencepiece/yi.model
new file mode 100644
index 000000000..0c3136e08
Binary files /dev/null and b/src/sentencepiece/yi.model differ
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 72cb4101c..c1bd90f75 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -76,6 +76,7 @@ const spp_llama = new SentencePieceTokenizer('src/sentencepiece/llama.model');
const spp_nerd = new SentencePieceTokenizer('src/sentencepiece/nerdstash.model');
const spp_nerd_v2 = new SentencePieceTokenizer('src/sentencepiece/nerdstash_v2.model');
const spp_mistral = new SentencePieceTokenizer('src/sentencepiece/mistral.model');
+const spp_yi = new SentencePieceTokenizer('src/sentencepiece/yi.model');
let claude_tokenizer;
const sentencepieceTokenizers = [
@@ -181,18 +182,6 @@ async function getWebTokenizersChunks(tokenizer, ids) {
* @returns {string} Tokenizer model to use
*/
function getTokenizerModel(requestModel) {
- if (requestModel.includes('claude')) {
- return 'claude';
- }
-
- if (requestModel.includes('llama')) {
- return 'llama';
- }
-
- if (requestModel.includes('mistral')) {
- return 'mistral';
- }
-
if (requestModel.includes('gpt-4-32k')) {
return 'gpt-4-32k';
}
@@ -213,6 +202,22 @@ function getTokenizerModel(requestModel) {
return requestModel;
}
+ if (requestModel.includes('claude')) {
+ return 'claude';
+ }
+
+ if (requestModel.includes('llama')) {
+ return 'llama';
+ }
+
+ if (requestModel.includes('mistral')) {
+ return 'mistral';
+ }
+
+ if (requestModel.includes('yi')) {
+ return 'yi';
+ }
+
// default
return 'gpt-3.5-turbo';
}
@@ -386,11 +391,13 @@ function registerEndpoints(app, jsonParser) {
app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(spp_nerd));
app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(spp_mistral));
+ app.post("/api/tokenize/yi", jsonParser, createSentencepieceEncodingHandler(spp_yi));
app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(spp_llama));
app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(spp_nerd));
app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(spp_mistral));
+ app.post("/api/decode/yi", jsonParser, createSentencepieceDecodingHandler(spp_yi));
app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
app.post("/api/tokenize/openai-encode", jsonParser, async function (req, res) {
@@ -407,6 +414,11 @@ function registerEndpoints(app, jsonParser) {
return handler(req, res);
}
+ if (queryModel.includes('yi')) {
+ const handler = createSentencepieceEncodingHandler(spp_yi);
+ return handler(req, res);
+ }
+
if (queryModel.includes('claude')) {
const text = req.body.text || '';
const tokens = Object.values(claude_tokenizer.encode(text));
@@ -431,21 +443,26 @@ function registerEndpoints(app, jsonParser) {
const queryModel = String(req.query.model || '');
const model = getTokenizerModel(queryModel);
- if (model == 'claude') {
+ if (model === 'claude') {
num_tokens = countClaudeTokens(claude_tokenizer, req.body);
return res.send({ "token_count": num_tokens });
}
- if (model == 'llama') {
+ if (model === 'llama') {
num_tokens = await countSentencepieceArrayTokens(spp_llama, req.body);
return res.send({ "token_count": num_tokens });
}
- if (model == 'mistral') {
+ if (model === 'mistral') {
num_tokens = await countSentencepieceArrayTokens(spp_mistral, req.body);
return res.send({ "token_count": num_tokens });
}
+ if (model === 'yi') {
+ num_tokens = await countSentencepieceArrayTokens(spp_yi, req.body);
+ return res.send({ "token_count": num_tokens });
+ }
+
const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
const tokensPadding = 3;