diff --git a/public/scripts/extensions/token-counter/index.js b/public/scripts/extensions/token-counter/index.js index c39240c0e..e0213def2 100644 --- a/public/scripts/extensions/token-counter/index.js +++ b/public/scripts/extensions/token-counter/index.js @@ -22,11 +22,15 @@ async function doTokenCounter() {
Selected tokenizer: ${selectedTokenizer}
- +${chunk}
`);
+ chunkHtml.attr('title', ids[i]);
+ $('#tokenized_chunks_display').append(chunkHtml);
+ }
+}
+
function doCount() {
// get all of the messages in the chat
const context = getContext();
diff --git a/public/scripts/extensions/token-counter/style.css b/public/scripts/extensions/token-counter/style.css
index e69de29bb..5af284a13 100644
--- a/public/scripts/extensions/token-counter/style.css
+++ b/public/scripts/extensions/token-counter/style.css
@@ -0,0 +1,4 @@
+#tokenized_chunks_display > code {
+ color: black;
+ text-shadow: none;
+}
diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js
index e78786646..a97e7dc01 100644
--- a/public/scripts/tokenizers.js
+++ b/public/scripts/tokenizers.js
@@ -385,6 +385,11 @@ function getTextTokensRemote(endpoint, str, model = '') {
contentType: "application/json",
success: function (data) {
ids = data.ids;
+
+ // Don't want to break reverse compatibility, so sprinkle in some of the JS magic
+ if (Array.isArray(data.chunks)) {
+ Object.defineProperty(ids, 'chunks', { value: data.chunks });
+ }
}
});
return ids;
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 7cc440e37..264e0706d 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -78,6 +78,39 @@ async function countSentencepieceTokens(spp, text) {
};
}
+async function countSentencepieceArrayTokens(tokenizer, array) {
+ const jsonBody = array.flatMap(x => Object.values(x)).join('\n\n');
+ const result = await countSentencepieceTokens(tokenizer, jsonBody);
+ const num_tokens = result.count;
+ return num_tokens;
+}
+
+async function getTiktokenChunks(tokenizer, ids) {
+ const decoder = new TextDecoder();
+ const chunks = [];
+
+ for (let i = 0; i < ids.length; i++) {
+ const id = ids[i];
+ const chunkTextBytes = await tokenizer.decode(new Uint32Array([id]));
+ const chunkText = decoder.decode(chunkTextBytes);
+ chunks.push(chunkText);
+ }
+
+ return chunks;
+}
+
+async function getWebTokenizersChunks(tokenizer, ids) {
+ const chunks = [];
+
+ for (let i = 0; i < ids.length; i++) {
+ const id = ids[i];
+ const chunkText = await tokenizer.decode(new Uint32Array([id]));
+ chunks.push(chunkText);
+ }
+
+ return chunks;
+}
+
/**
* Gets the tokenizer model by the model name.
* @param {string} requestModel Models to use for tokenization
@@ -169,10 +202,11 @@ function createSentencepieceEncodingHandler(getTokenizerFn) {
const text = request.body.text || '';
const tokenizer = getTokenizerFn();
const { ids, count } = await countSentencepieceTokens(tokenizer, text);
- return response.send({ ids, count });
+ const chunks = await tokenizer.encodePieces(text);
+ return response.send({ ids, count, chunks });
} catch (error) {
console.log(error);
- return response.send({ ids: [], count: 0 });
+ return response.send({ ids: [], count: 0, chunks: [] });
}
};
}
@@ -215,10 +249,11 @@ function createTiktokenEncodingHandler(modelId) {
const text = request.body.text || '';
const tokenizer = getTiktokenTokenizer(modelId);
const tokens = Object.values(tokenizer.encode(text));
- return response.send({ ids: tokens, count: tokens.length });
+ const chunks = await getTiktokenChunks(tokenizer, tokens);
+ return response.send({ ids: tokens, count: tokens.length, chunks });
} catch (error) {
console.log(error);
- return response.send({ ids: [], count: 0 });
+ return response.send({ ids: [], count: 0, chunks: [] });
}
}
}
@@ -317,7 +352,8 @@ function registerEndpoints(app, jsonParser) {
if (queryModel.includes('claude')) {
const text = req.body.text || '';
const tokens = Object.values(claude_tokenizer.encode(text));
- return res.send({ ids: tokens, count: tokens.length });
+ const chunks = await getWebTokenizersChunks(claude_tokenizer, tokens);
+ return res.send({ ids: tokens, count: tokens.length, chunks });
}
const model = getTokenizerModel(queryModel);
@@ -325,7 +361,7 @@ function registerEndpoints(app, jsonParser) {
return handler(req, res);
} catch (error) {
console.log(error);
- return res.send({ ids: [], count: 0 });
+ return res.send({ ids: [], count: 0, chunks: [] });
}
});
@@ -343,16 +379,12 @@ function registerEndpoints(app, jsonParser) {
}
if (model == 'llama') {
- const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
- const llamaResult = await countSentencepieceTokens(spp_llama, jsonBody);
- num_tokens = llamaResult.count;
+ num_tokens = await countSentencepieceArrayTokens(spp_llama, req.body);
return res.send({ "token_count": num_tokens });
}
if (model == 'mistral') {
- const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
- const mistralResult = await countSentencepieceTokens(spp_mistral, jsonBody);
- num_tokens = mistralResult.count;
+ num_tokens = await countSentencepieceArrayTokens(spp_mistral, req.body);
return res.send({ "token_count": num_tokens });
}
@@ -407,3 +439,4 @@ module.exports = {
loadTokenizers,
registerEndpoints,
}
+