From e8ba328a1497636d30e8db4f201dac70a54e4d9a Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Mon, 6 Nov 2023 02:42:51 +0200 Subject: [PATCH] Add text chunks display to token counter --- .../scripts/extensions/token-counter/index.js | 55 ++++++++++++++++-- .../extensions/token-counter/style.css | 4 ++ public/scripts/tokenizers.js | 5 ++ src/tokenizers.js | 57 +++++++++++++++---- 4 files changed, 105 insertions(+), 16 deletions(-) diff --git a/public/scripts/extensions/token-counter/index.js b/public/scripts/extensions/token-counter/index.js index c39240c0e..e0213def2 100644 --- a/public/scripts/extensions/token-counter/index.js +++ b/public/scripts/extensions/token-counter/index.js @@ -22,11 +22,15 @@ async function doTokenCounter() {

Type / paste in the box below to see the number of tokens in the text.

Selected tokenizer: ${selectedTokenizer}

- +
Input:
+
Tokens: 0

-
Token IDs (if applicable):
- +
Tokenized text:
+
+
+
Token IDs:
+
`; @@ -36,13 +40,18 @@ async function doTokenCounter() { const ids = main_api == 'openai' ? getTextTokens(tokenizers.OPENAI, text) : getTextTokens(tokenizerId, text); if (Array.isArray(ids) && ids.length > 0) { - $('#token_counter_ids').text(JSON.stringify(ids)); + $('#token_counter_ids').text(`[${ids.join(', ')}]`); $('#token_counter_result').text(ids.length); + + if (Object.hasOwnProperty.call(ids, 'chunks')) { + drawChunks(Object.getOwnPropertyDescriptor(ids, 'chunks').value, ids); + } } else { const context = getContext(); const count = context.getTokenCount(text); $('#token_counter_ids').text('—'); $('#token_counter_result').text(count); + $('#tokenized_chunks_display').text('—'); } }); @@ -50,6 +59,44 @@ async function doTokenCounter() { callPopup(dialog, 'text', '', { wide: true, large: true }); } +/** + * Draws the tokenized chunks in the UI + * @param {string[]} chunks + * @param {number[]} ids + */ +function drawChunks(chunks, ids) { + const pastelRainbow = [ + '#FFB3BA', + '#FFDFBA', + '#FFFFBA', + '#BFFFBF', + '#BAE1FF', + '#FFBAF3', + ]; + $('#tokenized_chunks_display').empty(); + + for (let i = 0; i < chunks.length; i++) { + let chunk = chunks[i].replace(/▁/g, ' '); // This is a leading space in sentencepiece. More info: Lower one eighth block (U+2581) + + // If <0xHEX>, decode it + if (/^<0x[0-9A-F]+>$/i.test(chunk)) { + const code = parseInt(chunk.substring(3, chunk.length - 1), 16); + chunk = String.fromCodePoint(code); + } + + // If newline - insert a line break + if (chunk === '\n') { + $('#tokenized_chunks_display').append('
'); + continue; + } + + const color = pastelRainbow[i % pastelRainbow.length]; + const chunkHtml = $(`${chunk}`); + chunkHtml.attr('title', ids[i]); + $('#tokenized_chunks_display').append(chunkHtml); + } +} + function doCount() { // get all of the messages in the chat const context = getContext(); diff --git a/public/scripts/extensions/token-counter/style.css b/public/scripts/extensions/token-counter/style.css index e69de29bb..5af284a13 100644 --- a/public/scripts/extensions/token-counter/style.css +++ b/public/scripts/extensions/token-counter/style.css @@ -0,0 +1,4 @@ +#tokenized_chunks_display > code { + color: black; + text-shadow: none; +} diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js index e78786646..a97e7dc01 100644 --- a/public/scripts/tokenizers.js +++ b/public/scripts/tokenizers.js @@ -385,6 +385,11 @@ function getTextTokensRemote(endpoint, str, model = '') { contentType: "application/json", success: function (data) { ids = data.ids; + + // Don't want to break reverse compatibility, so sprinkle in some of the JS magic + if (Array.isArray(data.chunks)) { + Object.defineProperty(ids, 'chunks', { value: data.chunks }); + } } }); return ids; diff --git a/src/tokenizers.js b/src/tokenizers.js index 7cc440e37..264e0706d 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -78,6 +78,39 @@ async function countSentencepieceTokens(spp, text) { }; } +async function countSentencepieceArrayTokens(tokenizer, array) { + const jsonBody = array.flatMap(x => Object.values(x)).join('\n\n'); + const result = await countSentencepieceTokens(tokenizer, jsonBody); + const num_tokens = result.count; + return num_tokens; +} + +async function getTiktokenChunks(tokenizer, ids) { + const decoder = new TextDecoder(); + const chunks = []; + + for (let i = 0; i < ids.length; i++) { + const id = ids[i]; + const chunkTextBytes = await tokenizer.decode(new Uint32Array([id])); + const chunkText = decoder.decode(chunkTextBytes); + chunks.push(chunkText); + } + + return chunks; +} + +async function getWebTokenizersChunks(tokenizer, ids) { + const chunks = []; + + for (let i = 0; i < ids.length; i++) { + const id = ids[i]; + const chunkText = await tokenizer.decode(new Uint32Array([id])); + chunks.push(chunkText); + } + + return chunks; +} + /** * Gets the tokenizer model by the model name. * @param {string} requestModel Models to use for tokenization @@ -169,10 +202,11 @@ function createSentencepieceEncodingHandler(getTokenizerFn) { const text = request.body.text || ''; const tokenizer = getTokenizerFn(); const { ids, count } = await countSentencepieceTokens(tokenizer, text); - return response.send({ ids, count }); + const chunks = await tokenizer.encodePieces(text); + return response.send({ ids, count, chunks }); } catch (error) { console.log(error); - return response.send({ ids: [], count: 0 }); + return response.send({ ids: [], count: 0, chunks: [] }); } }; } @@ -215,10 +249,11 @@ function createTiktokenEncodingHandler(modelId) { const text = request.body.text || ''; const tokenizer = getTiktokenTokenizer(modelId); const tokens = Object.values(tokenizer.encode(text)); - return response.send({ ids: tokens, count: tokens.length }); + const chunks = await getTiktokenChunks(tokenizer, tokens); + return response.send({ ids: tokens, count: tokens.length, chunks }); } catch (error) { console.log(error); - return response.send({ ids: [], count: 0 }); + return response.send({ ids: [], count: 0, chunks: [] }); } } } @@ -317,7 +352,8 @@ function registerEndpoints(app, jsonParser) { if (queryModel.includes('claude')) { const text = req.body.text || ''; const tokens = Object.values(claude_tokenizer.encode(text)); - return res.send({ ids: tokens, count: tokens.length }); + const chunks = await getWebTokenizersChunks(claude_tokenizer, tokens); + return res.send({ ids: tokens, count: tokens.length, chunks }); } const model = getTokenizerModel(queryModel); @@ -325,7 +361,7 @@ function registerEndpoints(app, jsonParser) { return handler(req, res); } catch (error) { console.log(error); - return res.send({ ids: [], count: 0 }); + return res.send({ ids: [], count: 0, chunks: [] }); } }); @@ -343,16 +379,12 @@ function registerEndpoints(app, jsonParser) { } if (model == 'llama') { - const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n'); - const llamaResult = await countSentencepieceTokens(spp_llama, jsonBody); - num_tokens = llamaResult.count; + num_tokens = await countSentencepieceArrayTokens(spp_llama, req.body); return res.send({ "token_count": num_tokens }); } if (model == 'mistral') { - const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n'); - const mistralResult = await countSentencepieceTokens(spp_mistral, jsonBody); - num_tokens = mistralResult.count; + num_tokens = await countSentencepieceArrayTokens(spp_mistral, req.body); return res.send({ "token_count": num_tokens }); } @@ -407,3 +439,4 @@ module.exports = { loadTokenizers, registerEndpoints, } +