Add text chunks display to token counter

This commit is contained in:
Cohee
2023-11-06 02:42:51 +02:00
parent f248367ca3
commit e8ba328a14
4 changed files with 105 additions and 16 deletions

View File

@@ -78,6 +78,39 @@ async function countSentencepieceTokens(spp, text) {
};
}
/**
 * Counts the tokens in an array of objects by joining every object's
 * values with blank lines and tokenizing the combined text.
 * @param {object} tokenizer Sentencepiece tokenizer instance
 * @param {object[]} array Array of objects whose values are concatenated
 * @returns {Promise<number>} Total token count of the joined text
 */
async function countSentencepieceArrayTokens(tokenizer, array) {
    // Flatten all values from every object into one '\n\n'-separated blob.
    const combinedText = array.flatMap(entry => Object.values(entry)).join('\n\n');
    const { count } = await countSentencepieceTokens(tokenizer, combinedText);
    return count;
}
/**
 * Decodes each token id into its text piece using a Tiktoken tokenizer.
 * The tokenizer returns raw UTF-8 bytes per id, which are decoded to strings.
 * @param {object} tokenizer Tiktoken tokenizer instance
 * @param {number[]} ids Token ids to decode, one id per chunk
 * @returns {Promise<string[]>} Decoded text chunk for every id, in order
 */
async function getTiktokenChunks(tokenizer, ids) {
    const utf8 = new TextDecoder();
    const pieces = [];

    for (const id of ids) {
        // decode() yields the byte sequence for this single-token input.
        const chunkBytes = await tokenizer.decode(new Uint32Array([id]));
        pieces.push(utf8.decode(chunkBytes));
    }

    return pieces;
}
/**
 * Decodes each token id into its text piece using a web-tokenizers tokenizer.
 * Unlike the Tiktoken variant, decode() here returns a string directly.
 * @param {object} tokenizer Web-tokenizers tokenizer instance
 * @param {number[]} ids Token ids to decode, one id per chunk
 * @returns {Promise<string[]>} Decoded text chunk for every id, in order
 */
async function getWebTokenizersChunks(tokenizer, ids) {
    const pieces = [];

    for (const id of ids) {
        pieces.push(await tokenizer.decode(new Uint32Array([id])));
    }

    return pieces;
}
/**
* Gets the tokenizer model by the model name.
* @param {string} requestModel Models to use for tokenization
@@ -169,10 +202,11 @@ function createSentencepieceEncodingHandler(getTokenizerFn) {
const text = request.body.text || '';
const tokenizer = getTokenizerFn();
const { ids, count } = await countSentencepieceTokens(tokenizer, text);
return response.send({ ids, count });
const chunks = await tokenizer.encodePieces(text);
return response.send({ ids, count, chunks });
} catch (error) {
console.log(error);
return response.send({ ids: [], count: 0 });
return response.send({ ids: [], count: 0, chunks: [] });
}
};
}
@@ -215,10 +249,11 @@ function createTiktokenEncodingHandler(modelId) {
const text = request.body.text || '';
const tokenizer = getTiktokenTokenizer(modelId);
const tokens = Object.values(tokenizer.encode(text));
return response.send({ ids: tokens, count: tokens.length });
const chunks = await getTiktokenChunks(tokenizer, tokens);
return response.send({ ids: tokens, count: tokens.length, chunks });
} catch (error) {
console.log(error);
return response.send({ ids: [], count: 0 });
return response.send({ ids: [], count: 0, chunks: [] });
}
}
}
@@ -317,7 +352,8 @@ function registerEndpoints(app, jsonParser) {
if (queryModel.includes('claude')) {
const text = req.body.text || '';
const tokens = Object.values(claude_tokenizer.encode(text));
return res.send({ ids: tokens, count: tokens.length });
const chunks = await getWebTokenizersChunks(claude_tokenizer, tokens);
return res.send({ ids: tokens, count: tokens.length, chunks });
}
const model = getTokenizerModel(queryModel);
@@ -325,7 +361,7 @@ function registerEndpoints(app, jsonParser) {
return handler(req, res);
} catch (error) {
console.log(error);
return res.send({ ids: [], count: 0 });
return res.send({ ids: [], count: 0, chunks: [] });
}
});
@@ -343,16 +379,12 @@ function registerEndpoints(app, jsonParser) {
}
if (model == 'llama') {
const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
const llamaResult = await countSentencepieceTokens(spp_llama, jsonBody);
num_tokens = llamaResult.count;
num_tokens = await countSentencepieceArrayTokens(spp_llama, req.body);
return res.send({ "token_count": num_tokens });
}
if (model == 'mistral') {
const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
const mistralResult = await countSentencepieceTokens(spp_mistral, jsonBody);
num_tokens = mistralResult.count;
num_tokens = await countSentencepieceArrayTokens(spp_mistral, req.body);
return res.send({ "token_count": num_tokens });
}
@@ -407,3 +439,4 @@ module.exports = {
loadTokenizers,
registerEndpoints,
}