From 3fb26d3927289a1b5d30535c0383f0773aee94f7 Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Wed, 15 Nov 2023 19:39:55 +0200
Subject: [PATCH] Add lazy loading of sentencepiece tokenizers

---
 server.js         |   4 +-
 src/tokenizers.js | 117 ++++++++++++++++++++++++++++------------------
 2 files changed, 73 insertions(+), 48 deletions(-)

diff --git a/server.js b/server.js
index 7ff4de46b..d4708854b 100644
--- a/server.js
+++ b/server.js
@@ -2795,13 +2795,13 @@ app.post("/openai_bias", jsonParser, async function (request, response) {
 
     if (sentencepieceTokenizers.includes(model)) {
         const tokenizer = getSentencepiceTokenizer(model);
-        encodeFunction = (text) => new Uint32Array(tokenizer.encodeIds(text));
+        const instance = await tokenizer?.get();
+        encodeFunction = (text) => new Uint32Array(instance?.encodeIds(text));
     } else {
         const tokenizer = getTiktokenTokenizer(model);
         encodeFunction = (tokenizer.encode.bind(tokenizer));
     }
 
-
     for (const entry of request.body) {
         if (!entry || !entry.text) {
             continue;
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 771da69fe..72cb4101c 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -1,4 +1,5 @@
 const fs = require('fs');
+const path = require('path');
 const { SentencePieceProcessor } = require("@agnai/sentencepiece-js");
 const tiktoken = require('@dqbd/tiktoken');
 const { Tokenizer } = require('@agnai/web-tokenizers');
@@ -43,22 +44,39 @@ const TEXT_COMPLETION_MODELS = [
 
 const CHARS_PER_TOKEN = 3.35;
 
-let spp_llama;
-let spp_nerd;
-let spp_nerd_v2;
-let spp_mistral;
-let claude_tokenizer;
+class SentencePieceTokenizer {
+    #instance;
+    #model;
 
-async function loadSentencepieceTokenizer(modelPath) {
-    try {
-        const spp = new SentencePieceProcessor();
-        await spp.load(modelPath);
-        return spp;
-    } catch (error) {
-        console.error("Sentencepiece tokenizer failed to load: " + modelPath, error);
-        return null;
+    constructor(model) {
+        this.#model = model;
     }
-};
+
+    /**
+     * Gets the Sentencepiece tokenizer instance.
+     */
+    async get() {
+        if (this.#instance) {
+            return this.#instance;
+        }
+
+        try {
+            this.#instance = new SentencePieceProcessor();
+            await this.#instance.load(this.#model);
+            console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
+            return this.#instance;
+        } catch (error) {
+            console.error("Sentencepiece tokenizer failed to load: " + this.#model, error);
+            return null;
+        }
+    }
+}
+
+const spp_llama = new SentencePieceTokenizer('src/sentencepiece/llama.model');
+const spp_nerd = new SentencePieceTokenizer('src/sentencepiece/nerdstash.model');
+const spp_nerd_v2 = new SentencePieceTokenizer('src/sentencepiece/nerdstash_v2.model');
+const spp_mistral = new SentencePieceTokenizer('src/sentencepiece/mistral.model');
+let claude_tokenizer;
 
 const sentencepieceTokenizers = [
     'llama',
@@ -70,7 +88,7 @@ const sentencepieceTokenizers = [
 /**
  * Gets the Sentencepiece tokenizer by the model name.
  * @param {string} model Sentencepiece model name
- * @returns {*} Sentencepiece tokenizer
+ * @returns {SentencePieceTokenizer|null} Sentencepiece tokenizer
  */
 function getSentencepiceTokenizer(model) {
     if (model.includes('llama')) {
@@ -88,11 +106,21 @@ function getSentencepiceTokenizer(model) {
     if (model.includes('nerdstash_v2')) {
         return spp_nerd_v2;
     }
+
+    return null;
 }
 
-async function countSentencepieceTokens(spp, text) {
+/**
+ * Counts the token ids for the given text using the Sentencepiece tokenizer.
+ * @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
+ * @param {string} text Text to tokenize
+ * @returns { Promise<{ids: number[], count: number}> } Tokenization result
+ */
+async function countSentencepieceTokens(tokenizer, text) {
+    const instance = await tokenizer?.get();
+
     // Fallback to strlen estimation
-    if (!spp) {
+    if (!instance) {
         return {
             ids: [],
             count: Math.ceil(text.length / CHARS_PER_TOKEN)
@@ -101,13 +129,19 @@
 
     let cleaned = text; // cleanText(text); <-- cleaning text can result in an incorrect tokenization
 
-    let ids = spp.encodeIds(cleaned);
+    let ids = instance.encodeIds(cleaned);
     return {
         ids,
         count: ids.length
     };
 }
 
+/**
+ * Counts the tokens in the given array of objects using the Sentencepiece tokenizer.
+ * @param {SentencePieceTokenizer} tokenizer
+ * @param {object[]} array Array of objects to tokenize
+ * @returns {Promise} Number of tokens
+ */
 async function countSentencepieceArrayTokens(tokenizer, array) {
     const jsonBody = array.flatMap(x => Object.values(x)).join('\n\n');
     const result = await countSentencepieceTokens(tokenizer, jsonBody);
@@ -219,10 +253,10 @@ function countClaudeTokens(tokenizer, messages) {
 
 /**
  * Creates an API handler for encoding Sentencepiece tokens.
- * @param {function} getTokenizerFn Tokenizer provider function
+ * @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
  * @returns {any} Handler function
  */
-function createSentencepieceEncodingHandler(getTokenizerFn) {
+function createSentencepieceEncodingHandler(tokenizer) {
     return async function (request, response) {
         try {
             if (!request.body) {
@@ -230,9 +264,9 @@
             }
 
             const text = request.body.text || '';
-            const tokenizer = getTokenizerFn();
+            const instance = await tokenizer?.get();
             const { ids, count } = await countSentencepieceTokens(tokenizer, text);
-            const chunks = await tokenizer.encodePieces(text);
+            const chunks = await instance?.encodePieces(text);
             return response.send({ ids, count, chunks });
         } catch (error) {
             console.log(error);
@@ -243,10 +277,10 @@
 
 /**
  * Creates an API handler for decoding Sentencepiece tokens.
- * @param {function} getTokenizerFn Tokenizer provider function
+ * @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
  * @returns {any} Handler function
  */
-function createSentencepieceDecodingHandler(getTokenizerFn) {
+function createSentencepieceDecodingHandler(tokenizer) {
     return async function (request, response) {
         try {
             if (!request.body) {
@@ -254,8 +288,8 @@
             }
 
             const ids = request.body.ids || [];
-            const tokenizer = getTokenizerFn();
-            const text = await tokenizer.decodeIds(ids);
+            const instance = await tokenizer?.get();
+            const text = await instance?.decodeIds(ids);
             return response.send({ text });
         } catch (error) {
             console.log(error);
@@ -317,13 +351,7 @@ function createTiktokenDecodingHandler(modelId) {
  * @returns {Promise} Promise that resolves when the tokenizers are loaded
  */
 async function loadTokenizers() {
-    [spp_llama, spp_nerd, spp_nerd_v2, spp_mistral, claude_tokenizer] = await Promise.all([
-        loadSentencepieceTokenizer('src/sentencepiece/llama.model'),
-        loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
-        loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
-        loadSentencepieceTokenizer('src/sentencepiece/mistral.model'),
-        loadClaudeTokenizer('src/claude.json'),
-    ]);
+    claude_tokenizer = await loadClaudeTokenizer('src/claude.json');
 }
 
 /**
@@ -354,15 +382,15 @@ function registerEndpoints(app, jsonParser) {
         }
     });
 
-    app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(() => spp_llama));
-    app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd));
-    app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd_v2));
-    app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(() => spp_mistral));
+    app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(spp_llama));
+    app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(spp_nerd));
+    app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
+    app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(spp_mistral));
     app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
-    app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(() => spp_llama));
-    app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd));
-    app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2));
-    app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(() => spp_mistral));
+    app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(spp_llama));
+    app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(spp_nerd));
+    app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
+    app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(spp_mistral));
    app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
 
     app.post("/api/tokenize/openai-encode", jsonParser, async function (req, res) {
@@ -370,12 +398,12 @@
             const queryModel = String(req.query.model || '');
 
             if (queryModel.includes('llama')) {
-                const handler = createSentencepieceEncodingHandler(() => spp_llama);
+                const handler = createSentencepieceEncodingHandler(spp_llama);
                 return handler(req, res);
             }
 
             if (queryModel.includes('mistral')) {
-                const handler = createSentencepieceEncodingHandler(() => spp_mistral);
+                const handler = createSentencepieceEncodingHandler(spp_mistral);
                 return handler(req, res);
             }
 
@@ -462,9 +490,6 @@
     TEXT_COMPLETION_MODELS,
     getTokenizerModel,
     getTiktokenTokenizer,
-    loadSentencepieceTokenizer,
-    loadClaudeTokenizer,
-    countSentencepieceTokens,
     countClaudeTokens,
     loadTokenizers,
     registerEndpoints,
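
A note on the calling convention this patch establishes: consumers hold a SentencePieceTokenizer wrapper and call await tokenizer.get() at the point of use. The first call loads the .model file, later calls return the cached instance, and a null result means the load failed, in which case callers fall back to a character-count estimate. A minimal sketch of that consumer pattern, assuming the SentencePieceTokenizer class from the diff above (the estimateTokens helper name is hypothetical, not part of the commit):

    // Hypothetical helper mirroring the fallback logic of countSentencepieceTokens().
    const CHARS_PER_TOKEN = 3.35; // same heuristic src/tokenizers.js uses

    async function estimateTokens(tokenizer, text) {
        // Resolves the lazily loaded processor; the model file is read on first use only.
        const instance = await tokenizer?.get();
        if (!instance) {
            // Load failed: estimate from string length instead of real token ids.
            return Math.ceil(text.length / CHARS_PER_TOKEN);
        }
        return instance.encodeIds(text).length;
    }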
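
One subtlety of the lazy getter as written: get() assigns this.#instance before load() resolves, so a second caller arriving while the first load is still in flight can receive a processor that has not finished loading. If that ever matters, a common variant is to cache the in-flight promise instead of the instance. A sketch of that variant under the same null-on-failure contract (this is not what the commit does):

    const { SentencePieceProcessor } = require('@agnai/sentencepiece-js');

    // Hypothetical promise-caching variant of SentencePieceTokenizer.
    class LazyTokenizer {
        #promise;
        #model;

        constructor(model) {
            this.#model = model;
        }

        get() {
            // Cache the promise, not the instance: concurrent first callers
            // all await the same load instead of observing a half-loaded processor.
            this.#promise ??= (async () => {
                const spp = new SentencePieceProcessor();
                await spp.load(this.#model);
                return spp;
            })();
            return this.#promise.catch(() => null); // preserve the null-on-failure contract
        }
    }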
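
The HTTP surface is unchanged by the refactor: the encode endpoints still respond with { ids, count, chunks } and the decode endpoints with { text }. An illustrative client call against /api/tokenize/llama (the base URL is an assumption for the sketch, not something the patch defines):

    // Illustrative only: the endpoint path and response shape come from
    // createSentencepieceEncodingHandler(); host and port are assumed.
    const response = await fetch('http://127.0.0.1:8000/api/tokenize/llama', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ text: 'Hello, world!' }),
    });
    const { ids, count, chunks } = await response.json();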