Add lazy loading of sentencepiece tokenizers

This commit is contained in:
Cohee 2023-11-15 19:39:55 +02:00
parent 9199750afe
commit 3fb26d3927
2 changed files with 73 additions and 48 deletions

View File

@ -2795,13 +2795,13 @@ app.post("/openai_bias", jsonParser, async function (request, response) {
if (sentencepieceTokenizers.includes(model)) {
const tokenizer = getSentencepiceTokenizer(model);
encodeFunction = (text) => new Uint32Array(tokenizer.encodeIds(text));
const instance = await tokenizer?.get();
encodeFunction = (text) => new Uint32Array(instance?.encodeIds(text));
} else {
const tokenizer = getTiktokenTokenizer(model);
encodeFunction = (tokenizer.encode.bind(tokenizer));
}
for (const entry of request.body) {
if (!entry || !entry.text) {
continue;

View File

@ -1,4 +1,5 @@
const fs = require('fs');
const path = require('path');
const { SentencePieceProcessor } = require("@agnai/sentencepiece-js");
const tiktoken = require('@dqbd/tiktoken');
const { Tokenizer } = require('@agnai/web-tokenizers');
@ -43,22 +44,39 @@ const TEXT_COMPLETION_MODELS = [
const CHARS_PER_TOKEN = 3.35;
let spp_llama;
let spp_nerd;
let spp_nerd_v2;
let spp_mistral;
let claude_tokenizer;
class SentencePieceTokenizer {
#instance;
#model;
async function loadSentencepieceTokenizer(modelPath) {
try {
const spp = new SentencePieceProcessor();
await spp.load(modelPath);
return spp;
} catch (error) {
console.error("Sentencepiece tokenizer failed to load: " + modelPath, error);
return null;
constructor(model) {
this.#model = model;
}
};
/**
* Gets the Sentencepiece tokenizer instance.
*/
async get() {
if (this.#instance) {
return this.#instance;
}
try {
this.#instance = new SentencePieceProcessor();
await this.#instance.load(this.#model);
console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
return this.#instance;
} catch (error) {
console.error("Sentencepiece tokenizer failed to load: " + this.#model, error);
return null;
}
}
}
const spp_llama = new SentencePieceTokenizer('src/sentencepiece/llama.model');
const spp_nerd = new SentencePieceTokenizer('src/sentencepiece/nerdstash.model');
const spp_nerd_v2 = new SentencePieceTokenizer('src/sentencepiece/nerdstash_v2.model');
const spp_mistral = new SentencePieceTokenizer('src/sentencepiece/mistral.model');
let claude_tokenizer;
const sentencepieceTokenizers = [
'llama',
@ -70,7 +88,7 @@ const sentencepieceTokenizers = [
/**
* Gets the Sentencepiece tokenizer by the model name.
* @param {string} model Sentencepiece model name
* @returns {*} Sentencepiece tokenizer
* @returns {SentencePieceTokenizer|null} Sentencepiece tokenizer
*/
function getSentencepiceTokenizer(model) {
if (model.includes('llama')) {
@ -88,11 +106,21 @@ function getSentencepiceTokenizer(model) {
if (model.includes('nerdstash_v2')) {
return spp_nerd_v2;
}
return null;
}
async function countSentencepieceTokens(spp, text) {
/**
* Counts the token ids for the given text using the Sentencepiece tokenizer.
* @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
* @param {string} text Text to tokenize
* @returns { Promise<{ids: number[], count: number}> } Tokenization result
*/
async function countSentencepieceTokens(tokenizer, text) {
const instance = await tokenizer?.get();
// Fallback to strlen estimation
if (!spp) {
if (!instance) {
return {
ids: [],
count: Math.ceil(text.length / CHARS_PER_TOKEN)
@ -101,13 +129,19 @@ async function countSentencepieceTokens(spp, text) {
let cleaned = text; // cleanText(text); <-- cleaning text can result in an incorrect tokenization
let ids = spp.encodeIds(cleaned);
let ids = instance.encodeIds(cleaned);
return {
ids,
count: ids.length
};
}
/**
* Counts the tokens in the given array of objects using the Sentencepiece tokenizer.
* @param {SentencePieceTokenizer} tokenizer
* @param {object[]} array Array of objects to tokenize
* @returns {Promise<number>} Number of tokens
*/
async function countSentencepieceArrayTokens(tokenizer, array) {
const jsonBody = array.flatMap(x => Object.values(x)).join('\n\n');
const result = await countSentencepieceTokens(tokenizer, jsonBody);
@ -219,10 +253,10 @@ function countClaudeTokens(tokenizer, messages) {
/**
* Creates an API handler for encoding Sentencepiece tokens.
* @param {function} getTokenizerFn Tokenizer provider function
* @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
* @returns {any} Handler function
*/
function createSentencepieceEncodingHandler(getTokenizerFn) {
function createSentencepieceEncodingHandler(tokenizer) {
return async function (request, response) {
try {
if (!request.body) {
@ -230,9 +264,9 @@ function createSentencepieceEncodingHandler(getTokenizerFn) {
}
const text = request.body.text || '';
const tokenizer = getTokenizerFn();
const instance = await tokenizer?.get();
const { ids, count } = await countSentencepieceTokens(tokenizer, text);
const chunks = await tokenizer.encodePieces(text);
const chunks = await instance?.encodePieces(text);
return response.send({ ids, count, chunks });
} catch (error) {
console.log(error);
@ -243,10 +277,10 @@ function createSentencepieceEncodingHandler(getTokenizerFn) {
/**
* Creates an API handler for decoding Sentencepiece tokens.
* @param {function} getTokenizerFn Tokenizer provider function
* @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
* @returns {any} Handler function
*/
function createSentencepieceDecodingHandler(getTokenizerFn) {
function createSentencepieceDecodingHandler(tokenizer) {
return async function (request, response) {
try {
if (!request.body) {
@ -254,8 +288,8 @@ function createSentencepieceDecodingHandler(getTokenizerFn) {
}
const ids = request.body.ids || [];
const tokenizer = getTokenizerFn();
const text = await tokenizer.decodeIds(ids);
const instance = await tokenizer?.get();
const text = await instance?.decodeIds(ids);
return response.send({ text });
} catch (error) {
console.log(error);
@ -317,13 +351,7 @@ function createTiktokenDecodingHandler(modelId) {
* @returns {Promise<void>} Promise that resolves when the tokenizers are loaded
*/
async function loadTokenizers() {
[spp_llama, spp_nerd, spp_nerd_v2, spp_mistral, claude_tokenizer] = await Promise.all([
loadSentencepieceTokenizer('src/sentencepiece/llama.model'),
loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
loadSentencepieceTokenizer('src/sentencepiece/mistral.model'),
loadClaudeTokenizer('src/claude.json'),
]);
claude_tokenizer = await loadClaudeTokenizer('src/claude.json');
}
/**
@ -354,15 +382,15 @@ function registerEndpoints(app, jsonParser) {
}
});
app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(() => spp_llama));
app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd));
app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd_v2));
app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(() => spp_mistral));
app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(spp_llama));
app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(spp_nerd));
app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(spp_mistral));
app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(() => spp_llama));
app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd));
app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2));
app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(() => spp_mistral));
app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(spp_llama));
app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(spp_nerd));
app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(spp_mistral));
app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
app.post("/api/tokenize/openai-encode", jsonParser, async function (req, res) {
@ -370,12 +398,12 @@ function registerEndpoints(app, jsonParser) {
const queryModel = String(req.query.model || '');
if (queryModel.includes('llama')) {
const handler = createSentencepieceEncodingHandler(() => spp_llama);
const handler = createSentencepieceEncodingHandler(spp_llama);
return handler(req, res);
}
if (queryModel.includes('mistral')) {
const handler = createSentencepieceEncodingHandler(() => spp_mistral);
const handler = createSentencepieceEncodingHandler(spp_mistral);
return handler(req, res);
}
@ -462,9 +490,6 @@ module.exports = {
TEXT_COMPLETION_MODELS,
getTokenizerModel,
getTiktokenTokenizer,
loadSentencepieceTokenizer,
loadClaudeTokenizer,
countSentencepieceTokens,
countClaudeTokens,
loadTokenizers,
registerEndpoints,