Add text chunks display to token counter

This commit is contained in:
Cohee 2023-11-06 02:42:51 +02:00
parent f248367ca3
commit e8ba328a14
4 changed files with 105 additions and 16 deletions

View File

@ -22,11 +22,15 @@ async function doTokenCounter() {
<div class="justifyLeft">
<h4>Type / paste in the box below to see the number of tokens in the text.</h4>
<p>Selected tokenizer: ${selectedTokenizer}</p>
<textarea id="token_counter_textarea" class="wide100p textarea_compact margin-bot-10px" rows="15"></textarea>
<div>Input:</div>
<textarea id="token_counter_textarea" class="wide100p textarea_compact margin-bot-10px" rows="10"></textarea>
<div>Tokens: <span id="token_counter_result">0</span></div>
<br>
<div>Token IDs (if applicable):</div>
<textarea id="token_counter_ids" disabled rows="10"></textarea>
<div>Tokenized text:</div>
<div id="tokenized_chunks_display" class="wide100p"></div>
<br>
<div>Token IDs:</div>
<textarea id="token_counter_ids" disabled rows="10"></textarea>
</div>
</div>`;
@ -36,13 +40,18 @@ async function doTokenCounter() {
const ids = main_api == 'openai' ? getTextTokens(tokenizers.OPENAI, text) : getTextTokens(tokenizerId, text);
if (Array.isArray(ids) && ids.length > 0) {
$('#token_counter_ids').text(JSON.stringify(ids));
$('#token_counter_ids').text(`[${ids.join(', ')}]`);
$('#token_counter_result').text(ids.length);
if (Object.hasOwnProperty.call(ids, 'chunks')) {
drawChunks(Object.getOwnPropertyDescriptor(ids, 'chunks').value, ids);
}
} else {
const context = getContext();
const count = context.getTokenCount(text);
$('#token_counter_ids').text('—');
$('#token_counter_result').text(count);
$('#tokenized_chunks_display').text('—');
}
});
@ -50,6 +59,44 @@ async function doTokenCounter() {
callPopup(dialog, 'text', '', { wide: true, large: true });
}
/**
 * Draws the tokenized chunks in the UI.
 * @param {string[]} chunks Decoded text chunk for each token
 * @param {number[]} ids Token IDs aligned index-for-index with chunks
 */
function drawChunks(chunks, ids) {
    const pastelRainbow = [
        '#FFB3BA',
        '#FFDFBA',
        '#FFFFBA',
        '#BFFFBF',
        '#BAE1FF',
        '#FFBAF3',
    ];
    $('#tokenized_chunks_display').empty();

    for (let i = 0; i < chunks.length; i++) {
        // U+2581 (Lower One Eighth Block) marks a leading space in sentencepiece output
        let chunk = chunks[i].replace(/▁/g, ' ');

        // Byte-fallback tokens arrive as "<0xHEX>"; decode them to the actual character
        if (/^<0x[0-9A-F]+>$/i.test(chunk)) {
            const code = parseInt(chunk.substring(3, chunk.length - 1), 16);
            chunk = String.fromCodePoint(code);
        }

        // If newline - insert a line break
        if (chunk === '\n') {
            $('#tokenized_chunks_display').append('<br>');
            continue;
        }

        const color = pastelRainbow[i % pastelRainbow.length];
        // Set the chunk via .text() so HTML-special characters in tokens
        // ('<', '&', quotes) render literally instead of being parsed as
        // markup — interpolating the chunk into an HTML string was an
        // injection vector for user-pasted text.
        const chunkHtml = $('<code></code>')
            .css('background-color', color)
            .text(chunk)
            .attr('title', ids[i]);
        $('#tokenized_chunks_display').append(chunkHtml);
    }
}
function doCount() {
// get all of the messages in the chat
const context = getContext();

View File

@ -0,0 +1,4 @@
/* Chunk <code> elements get pastel inline background colors from JS;
   force dark, unshadowed text so it stays readable on any theme. */
#tokenized_chunks_display > code {
color: black;
text-shadow: none;
}

View File

@ -385,6 +385,11 @@ function getTextTokensRemote(endpoint, str, model = '') {
contentType: "application/json",
success: function (data) {
ids = data.ids;
// Don't want to break reverse compatibility, so sprinkle in some of the JS magic
if (Array.isArray(data.chunks)) {
Object.defineProperty(ids, 'chunks', { value: data.chunks });
}
}
});
return ids;

View File

@ -78,6 +78,39 @@ async function countSentencepieceTokens(spp, text) {
};
}
/**
 * Counts tokens for an array of message objects with a sentencepiece tokenizer.
 * All object values are joined into a single text blob separated by blank lines.
 * @param {*} tokenizer Sentencepiece tokenizer instance
 * @param {object[]} array Array of objects whose values are tokenized together
 * @returns {Promise<number>} Total token count for the joined text
 */
async function countSentencepieceArrayTokens(tokenizer, array) {
    // Flatten every object's values into one string, blank-line separated
    const joinedText = array
        .flatMap((item) => Object.values(item))
        .join('\n\n');
    const { count } = await countSentencepieceTokens(tokenizer, joinedText);
    return count;
}
/**
 * Decodes each token ID into its text chunk using a tiktoken tokenizer.
 * Tokens are decoded one at a time so chunk boundaries match the IDs.
 * @param {*} tokenizer Tiktoken tokenizer whose decode() yields UTF-8 bytes
 * @param {number[]} ids Token IDs to decode individually
 * @returns {Promise<string[]>} One decoded text chunk per input ID
 */
async function getTiktokenChunks(tokenizer, ids) {
    const utf8 = new TextDecoder();
    const chunks = [];

    for (const id of ids) {
        // decode() returns raw bytes for a single-token array; convert to text
        const chunkBytes = await tokenizer.decode(new Uint32Array([id]));
        chunks.push(utf8.decode(chunkBytes));
    }

    return chunks;
}
/**
 * Decodes each token ID into its text chunk using a web-tokenizers tokenizer.
 * @param {*} tokenizer Tokenizer whose decode() yields a string directly
 * @param {number[]} ids Token IDs to decode individually
 * @returns {Promise<string[]>} One decoded text chunk per input ID
 */
async function getWebTokenizersChunks(tokenizer, ids) {
    const chunks = [];

    // Decode sequentially, one token per call, to keep chunks aligned with ids
    for (const id of ids) {
        chunks.push(await tokenizer.decode(new Uint32Array([id])));
    }

    return chunks;
}
/**
* Gets the tokenizer model by the model name.
* @param {string} requestModel Models to use for tokenization
@ -169,10 +202,11 @@ function createSentencepieceEncodingHandler(getTokenizerFn) {
const text = request.body.text || '';
const tokenizer = getTokenizerFn();
const { ids, count } = await countSentencepieceTokens(tokenizer, text);
return response.send({ ids, count });
const chunks = await tokenizer.encodePieces(text);
return response.send({ ids, count, chunks });
} catch (error) {
console.log(error);
return response.send({ ids: [], count: 0 });
return response.send({ ids: [], count: 0, chunks: [] });
}
};
}
@ -215,10 +249,11 @@ function createTiktokenEncodingHandler(modelId) {
const text = request.body.text || '';
const tokenizer = getTiktokenTokenizer(modelId);
const tokens = Object.values(tokenizer.encode(text));
return response.send({ ids: tokens, count: tokens.length });
const chunks = await getTiktokenChunks(tokenizer, tokens);
return response.send({ ids: tokens, count: tokens.length, chunks });
} catch (error) {
console.log(error);
return response.send({ ids: [], count: 0 });
return response.send({ ids: [], count: 0, chunks: [] });
}
}
}
@ -317,7 +352,8 @@ function registerEndpoints(app, jsonParser) {
if (queryModel.includes('claude')) {
const text = req.body.text || '';
const tokens = Object.values(claude_tokenizer.encode(text));
return res.send({ ids: tokens, count: tokens.length });
const chunks = await getWebTokenizersChunks(claude_tokenizer, tokens);
return res.send({ ids: tokens, count: tokens.length, chunks });
}
const model = getTokenizerModel(queryModel);
@ -325,7 +361,7 @@ function registerEndpoints(app, jsonParser) {
return handler(req, res);
} catch (error) {
console.log(error);
return res.send({ ids: [], count: 0 });
return res.send({ ids: [], count: 0, chunks: [] });
}
});
@ -343,16 +379,12 @@ function registerEndpoints(app, jsonParser) {
}
if (model == 'llama') {
const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
const llamaResult = await countSentencepieceTokens(spp_llama, jsonBody);
num_tokens = llamaResult.count;
num_tokens = await countSentencepieceArrayTokens(spp_llama, req.body);
return res.send({ "token_count": num_tokens });
}
if (model == 'mistral') {
const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
const mistralResult = await countSentencepieceTokens(spp_mistral, jsonBody);
num_tokens = mistralResult.count;
num_tokens = await countSentencepieceArrayTokens(spp_mistral, req.body);
return res.send({ "token_count": num_tokens });
}
@ -407,3 +439,4 @@ module.exports = {
loadTokenizers,
registerEndpoints,
}