Add lazy loading of sentencepiece tokenizers

Cohee 2023-11-15 19:39:55 +02:00
parent 9199750afe
commit 3fb26d3927
2 changed files with 73 additions and 48 deletions

File 1:

@@ -2795,13 +2795,13 @@ app.post("/openai_bias", jsonParser, async function (request, response) {
     if (sentencepieceTokenizers.includes(model)) {
         const tokenizer = getSentencepiceTokenizer(model);
-        encodeFunction = (text) => new Uint32Array(tokenizer.encodeIds(text));
+        const instance = await tokenizer?.get();
+        encodeFunction = (text) => new Uint32Array(instance?.encodeIds(text));
     } else {
         const tokenizer = getTiktokenTokenizer(model);
         encodeFunction = (tokenizer.encode.bind(tokenizer));
     }

     for (const entry of request.body) {
         if (!entry || !entry.text) {
             continue;
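This hunk is the consumer side of the change: instead of receiving a ready tokenizer, the endpoint now holds a lazy wrapper and awaits its instance before building the encode function. A minimal sketch of that flow, assuming the helpers named in the diff (the resolveEncodeFunction wrapper itself is hypothetical):

    // Sketch: resolve an encode function for a model name. The sentencepiece
    // model file is read at most once, inside SentencePieceTokenizer#get(),
    // rather than at server startup.
    async function resolveEncodeFunction(model) {
        if (sentencepieceTokenizers.includes(model)) {
            const tokenizer = getSentencepiceTokenizer(model); // may be null
            const instance = await tokenizer?.get();           // lazy load here
            return (text) => new Uint32Array(instance?.encodeIds(text));
        }
        const tokenizer = getTiktokenTokenizer(model);
        return tokenizer.encode.bind(tokenizer);
    }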

File 2:

@@ -1,4 +1,5 @@
 const fs = require('fs');
+const path = require('path');
 const { SentencePieceProcessor } = require("@agnai/sentencepiece-js");
 const tiktoken = require('@dqbd/tiktoken');
 const { Tokenizer } = require('@agnai/web-tokenizers');
@@ -43,22 +44,39 @@ const TEXT_COMPLETION_MODELS = [
 const CHARS_PER_TOKEN = 3.35;

-let spp_llama;
-let spp_nerd;
-let spp_nerd_v2;
-let spp_mistral;
-let claude_tokenizer;
-
-async function loadSentencepieceTokenizer(modelPath) {
-    try {
-        const spp = new SentencePieceProcessor();
-        await spp.load(modelPath);
-        return spp;
-    } catch (error) {
-        console.error("Sentencepiece tokenizer failed to load: " + modelPath, error);
-        return null;
-    }
-};
+class SentencePieceTokenizer {
+    #instance;
+    #model;
+
+    constructor(model) {
+        this.#model = model;
+    }
+
+    /**
+     * Gets the Sentencepiece tokenizer instance.
+     */
+    async get() {
+        if (this.#instance) {
+            return this.#instance;
+        }
+
+        try {
+            this.#instance = new SentencePieceProcessor();
+            await this.#instance.load(this.#model);
+            console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
+            return this.#instance;
+        } catch (error) {
+            console.error("Sentencepiece tokenizer failed to load: " + this.#model, error);
+            return null;
+        }
+    }
+}
+
+const spp_llama = new SentencePieceTokenizer('src/sentencepiece/llama.model');
+const spp_nerd = new SentencePieceTokenizer('src/sentencepiece/nerdstash.model');
+const spp_nerd_v2 = new SentencePieceTokenizer('src/sentencepiece/nerdstash_v2.model');
+const spp_mistral = new SentencePieceTokenizer('src/sentencepiece/mistral.model');
+let claude_tokenizer;

 const sentencepieceTokenizers = [
     'llama',
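The class above is a memoized async initializer: constructing a SentencePieceTokenizer is free, and the model file is only read on the first get(). A self-contained sketch of the same pattern, where loadExpensiveResource is a stand-in for any slow async setup (not part of the diff):

    // Lazy-singleton sketch: the factory runs on the first get(), after which
    // the result is cached. Caching the promise (rather than the resolved
    // value) also lets concurrent first callers share one in-flight load.
    class Lazy {
        #promise;
        #factory;

        constructor(factory) {
            this.#factory = factory;
        }

        async get() {
            this.#promise ??= this.#factory();
            return this.#promise;
        }
    }

    // Nothing is loaded here:
    const resource = new Lazy(() => loadExpensiveResource('model.bin'));
    // The first get() triggers the load; later calls reuse the cached result:
    // const value = await resource.get();

One design note: the commit's implementation assigns #instance before awaiting load(), so two requests hitting a cold tokenizer at once could briefly race; the promise-caching variant sketched here is one way to avoid that, not what the commit does.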
@@ -70,7 +88,7 @@ const sentencepieceTokenizers = [
 /**
  * Gets the Sentencepiece tokenizer by the model name.
  * @param {string} model Sentencepiece model name
- * @returns {*} Sentencepiece tokenizer
+ * @returns {SentencePieceTokenizer|null} Sentencepiece tokenizer
  */
 function getSentencepiceTokenizer(model) {
     if (model.includes('llama')) {
@@ -88,11 +106,21 @@ function getSentencepiceTokenizer(model) {
     if (model.includes('nerdstash_v2')) {
         return spp_nerd_v2;
     }
+
+    return null;
 }

-async function countSentencepieceTokens(spp, text) {
+/**
+ * Counts the token ids for the given text using the Sentencepiece tokenizer.
+ * @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
+ * @param {string} text Text to tokenize
+ * @returns {Promise<{ids: number[], count: number}>} Tokenization result
+ */
+async function countSentencepieceTokens(tokenizer, text) {
+    const instance = await tokenizer?.get();
+
     // Fallback to strlen estimation
-    if (!spp) {
+    if (!instance) {
         return {
             ids: [],
             count: Math.ceil(text.length / CHARS_PER_TOKEN)
@@ -101,13 +129,19 @@ async function countSentencepieceTokens(spp, text) {
     let cleaned = text; // cleanText(text); <-- cleaning text can result in an incorrect tokenization

-    let ids = spp.encodeIds(cleaned);
+    let ids = instance.encodeIds(cleaned);
     return {
         ids,
         count: ids.length
     };
 }

+/**
+ * Counts the tokens in the given array of objects using the Sentencepiece tokenizer.
+ * @param {SentencePieceTokenizer} tokenizer
+ * @param {object[]} array Array of objects to tokenize
+ * @returns {Promise<number>} Number of tokens
+ */
 async function countSentencepieceArrayTokens(tokenizer, array) {
     const jsonBody = array.flatMap(x => Object.values(x)).join('\n\n');
     const result = await countSentencepieceTokens(tokenizer, jsonBody);
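Also worth noting: countSentencepieceTokens() degrades gracefully. If the wrapper is null (unknown model) or get() resolves to null (the model file failed to load), it falls back to the CHARS_PER_TOKEN estimate instead of throwing. A quick illustration of the fallback arithmetic (the model path here is deliberately bogus):

    // Fallback path: a tokenizer that cannot load resolves to null from
    // get(), so the count is estimated at 3.35 characters per token.
    const broken = new SentencePieceTokenizer('does/not/exist.model');
    const text = 'The quick brown fox jumps over the lazy dog.'; // 44 chars
    const { ids, count } = await countSentencepieceTokens(broken, text);
    // ids => [], count => Math.ceil(44 / 3.35) => 14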
@@ -219,10 +253,10 @@ function countClaudeTokens(tokenizer, messages) {
 /**
  * Creates an API handler for encoding Sentencepiece tokens.
- * @param {function} getTokenizerFn Tokenizer provider function
+ * @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
  * @returns {any} Handler function
  */
-function createSentencepieceEncodingHandler(getTokenizerFn) {
+function createSentencepieceEncodingHandler(tokenizer) {
     return async function (request, response) {
         try {
             if (!request.body) {
@@ -230,9 +264,9 @@ function createSentencepieceEncodingHandler(getTokenizerFn) {
             }

             const text = request.body.text || '';
-            const tokenizer = getTokenizerFn();
+            const instance = await tokenizer?.get();
             const { ids, count } = await countSentencepieceTokens(tokenizer, text);
-            const chunks = await tokenizer.encodePieces(text);
+            const chunks = await instance?.encodePieces(text);
             return response.send({ ids, count, chunks });
         } catch (error) {
             console.log(error);
@@ -243,10 +277,10 @@ function createSentencepieceEncodingHandler(getTokenizerFn) {
 /**
  * Creates an API handler for decoding Sentencepiece tokens.
- * @param {function} getTokenizerFn Tokenizer provider function
+ * @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
  * @returns {any} Handler function
  */
-function createSentencepieceDecodingHandler(getTokenizerFn) {
+function createSentencepieceDecodingHandler(tokenizer) {
     return async function (request, response) {
         try {
             if (!request.body) {
@@ -254,8 +288,8 @@ function createSentencepieceDecodingHandler(getTokenizerFn) {
             }

             const ids = request.body.ids || [];
-            const tokenizer = getTokenizerFn();
-            const text = await tokenizer.decodeIds(ids);
+            const instance = await tokenizer?.get();
+            const text = await instance?.decodeIds(ids);
             return response.send({ text });
         } catch (error) {
             console.log(error);
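In both handlers the optional chaining guards two distinct failure modes, which is easy to miss in the diff. A paraphrased sketch of the per-request flow (not the verbatim handler body):

    const instance = await tokenizer?.get();
    // 1. `tokenizer` may be missing entirely, so `tokenizer?.get()`
    //    short-circuits to undefined instead of throwing.
    // 2. get() itself resolves to null when the model file failed to load.
    // Either way, `instance?.decodeIds(ids)` yields undefined rather than
    // crashing, and the handler still sends a (degraded) response.
    const text = await instance?.decodeIds(ids);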
@@ -317,13 +351,7 @@ function createTiktokenDecodingHandler(modelId) {
  * @returns {Promise<void>} Promise that resolves when the tokenizers are loaded
  */
 async function loadTokenizers() {
-    [spp_llama, spp_nerd, spp_nerd_v2, spp_mistral, claude_tokenizer] = await Promise.all([
-        loadSentencepieceTokenizer('src/sentencepiece/llama.model'),
-        loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
-        loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
-        loadSentencepieceTokenizer('src/sentencepiece/mistral.model'),
-        loadClaudeTokenizer('src/claude.json'),
-    ]);
+    claude_tokenizer = await loadClaudeTokenizer('src/claude.json');
 }

 /**
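With this hunk, startup only awaits the Claude tokenizer; each sentencepiece model is now loaded by the first request that needs it. Assuming the module-level instances declared earlier in the diff, the observable behavior is:

    // Startup: constructing the SentencePieceTokenizer objects does no I/O.
    // First use: reads and parses src/sentencepiece/llama.model once.
    const first = await spp_llama.get();
    // Subsequent uses: the cached instance is returned immediately.
    const again = await spp_llama.get(); // same object, no file access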
@@ -354,15 +382,15 @@ function registerEndpoints(app, jsonParser) {
         }
     });

-    app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(() => spp_llama));
-    app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd));
-    app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd_v2));
-    app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(() => spp_mistral));
+    app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(spp_llama));
+    app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(spp_nerd));
+    app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
+    app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(spp_mistral));
     app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
-    app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(() => spp_llama));
-    app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd));
-    app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2));
-    app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(() => spp_mistral));
+    app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(spp_llama));
+    app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(spp_nerd));
+    app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
+    app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(spp_mistral));
     app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));

     app.post("/api/tokenize/openai-encode", jsonParser, async function (req, res) {
@@ -370,12 +398,12 @@ function registerEndpoints(app, jsonParser) {
         const queryModel = String(req.query.model || '');

         if (queryModel.includes('llama')) {
-            const handler = createSentencepieceEncodingHandler(() => spp_llama);
+            const handler = createSentencepieceEncodingHandler(spp_llama);
             return handler(req, res);
         }

         if (queryModel.includes('mistral')) {
-            const handler = createSentencepieceEncodingHandler(() => spp_mistral);
+            const handler = createSentencepieceEncodingHandler(spp_mistral);
             return handler(req, res);
         }
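Since the handlers now receive the wrapper objects directly, the HTTP contract is unchanged; only the cost of the model load moves from startup to the first call. A hypothetical client round-trip against the endpoints registered above (host and port are assumptions, not from the diff):

    // Encode, then decode, via the registered endpoints.
    const enc = await fetch('http://localhost:8000/api/tokenize/llama', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ text: 'Hello world' }),
    }).then(r => r.json());
    // enc => { ids: [...], count: <number>, chunks: [...] }

    const dec = await fetch('http://localhost:8000/api/decode/llama', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ ids: enc.ids }),
    }).then(r => r.json());
    // dec => { text: 'Hello world' } (round-trip, modulo normalization)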
@@ -462,9 +490,6 @@ module.exports = {
     TEXT_COMPLETION_MODELS,
     getTokenizerModel,
     getTiktokenTokenizer,
-    loadSentencepieceTokenizer,
-    loadClaudeTokenizer,
-    countSentencepieceTokens,
     countClaudeTokens,
     loadTokenizers,
     registerEndpoints,