diff --git a/server.js b/server.js
index 86b48e57..0d5f7903 100644
--- a/server.js
+++ b/server.js
@@ -45,7 +45,6 @@ const {
     forwardFetchResponse,
 } = require('./src/util');
 const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
-const { loadTokenizers } = require('./src/endpoints/tokenizers');
 
 // Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
 // https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
@@ -548,7 +547,6 @@ const setupTasks = async function () {
     await ensureThumbnailCache();
     cleanUploads();
 
-    await loadTokenizers();
     await settingsEndpoint.init();
     await statsEndpoint.init();
 
diff --git a/src/endpoints/tokenizers.js b/src/endpoints/tokenizers.js
index d7dddf5f..65e79529 100644
--- a/src/endpoints/tokenizers.js
+++ b/src/endpoints/tokenizers.js
@@ -10,6 +10,10 @@ const { TEXTGEN_TYPES } = require('../constants');
 const { jsonParser } = require('../express-common');
 const { setAdditionalHeaders } = require('../additional-headers');
 
+/**
+ * @typedef { (req: import('express').Request, res: import('express').Response) => Promise<any> } TokenizationHandler
+ */
+
 /**
  * @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
  */
@@ -48,16 +52,30 @@ const TEXT_COMPLETION_MODELS = [
 
 const CHARS_PER_TOKEN = 3.35;
 
+/**
+ * Sentencepiece tokenizer for tokenizing text.
+ */
 class SentencePieceTokenizer {
+    /**
+     * @type {import('@agnai/sentencepiece-js').SentencePieceProcessor} Sentencepiece tokenizer instance
+     */
     #instance;
+    /**
+     * @type {string} Path to the tokenizer model
+     */
     #model;
 
+    /**
+     * Creates a new Sentencepiece tokenizer.
+     * @param {string} model Path to the tokenizer model
+     */
     constructor(model) {
         this.#model = model;
     }
 
     /**
      * Gets the Sentencepiece tokenizer instance.
+     * @returns {Promise<import('@agnai/sentencepiece-js').SentencePieceProcessor|null>} Sentencepiece tokenizer instance
      */
     async get() {
         if (this.#instance) {
@@ -76,18 +94,61 @@
     }
 }
 
-const spp_llama = new SentencePieceTokenizer('src/sentencepiece/llama.model');
-const spp_nerd = new SentencePieceTokenizer('src/sentencepiece/nerdstash.model');
-const spp_nerd_v2 = new SentencePieceTokenizer('src/sentencepiece/nerdstash_v2.model');
-const spp_mistral = new SentencePieceTokenizer('src/sentencepiece/mistral.model');
-const spp_yi = new SentencePieceTokenizer('src/sentencepiece/yi.model');
-let claude_tokenizer;
+/**
+ * Web tokenizer for tokenizing text.
+ */
+class WebTokenizer {
+    /**
+     * @type {Tokenizer} Web tokenizer instance
+     */
+    #instance;
+    /**
+     * @type {string} Path to the tokenizer model
+     */
+    #model;
+
+    /**
+     * Creates a new Web tokenizer.
+     * @param {string} model Path to the tokenizer model
+     */
+    constructor(model) {
+        this.#model = model;
+    }
+
+    /**
+     * Gets the Web tokenizer instance.
+     * @returns {Promise<Tokenizer|null>} Web tokenizer instance
+     */
+    async get() {
+        if (this.#instance) {
+            return this.#instance;
+        }
+
+        try {
+            const arrayBuffer = fs.readFileSync(this.#model).buffer;
+            this.#instance = await Tokenizer.fromJSON(arrayBuffer);
+            console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
+            return this.#instance;
+        } catch (error) {
+            console.error('Web tokenizer failed to load: ' + this.#model, error);
+            return null;
+        }
+    }
+}
+
+const spp_llama = new SentencePieceTokenizer('src/tokenizers/llama.model');
+const spp_nerd = new SentencePieceTokenizer('src/tokenizers/nerdstash.model');
+const spp_nerd_v2 = new SentencePieceTokenizer('src/tokenizers/nerdstash_v2.model');
+const spp_mistral = new SentencePieceTokenizer('src/tokenizers/mistral.model');
+const spp_yi = new SentencePieceTokenizer('src/tokenizers/yi.model');
+const claude_tokenizer = new WebTokenizer('src/tokenizers/claude.json');
 
 const sentencepieceTokenizers = [
     'llama',
     'nerdstash',
     'nerdstash_v2',
     'mistral',
+    'yi',
 ];
 
 /**
@@ -112,6 +173,10 @@ function getSentencepiceTokenizer(model) {
         return spp_nerd_v2;
     }
 
+    if (model.includes('yi')) {
+        return spp_yi;
+    }
+
     return null;
 }
 
@@ -168,13 +233,23 @@ async function getTiktokenChunks(tokenizer, ids) {
     return chunks;
 }
 
-async function getWebTokenizersChunks(tokenizer, ids) {
+/**
+ * Gets the token chunks for the given token IDs using the Web tokenizer.
+ * @param {Tokenizer} tokenizer Web tokenizer instance
+ * @param {number[]} ids Token IDs
+ * @returns {string[]} Token chunks
+ */
+function getWebTokenizersChunks(tokenizer, ids) {
     const chunks = [];
 
-    for (let i = 0; i < ids.length; i++) {
-        const id = ids[i];
-        const chunkText = await tokenizer.decode(new Uint32Array([id]));
+    for (let i = 0, lastProcessed = 0; i < ids.length; i++) {
+        const chunkIds = ids.slice(lastProcessed, i + 1);
+        const chunkText = tokenizer.decode(new Int32Array(chunkIds));
+        if (chunkText === '�') {
+            continue;
+        }
         chunks.push(chunkText);
+        lastProcessed = i + 1;
     }
 
     return chunks;
@@ -237,17 +312,12 @@ function getTiktokenTokenizer(model) {
     return tokenizer;
 }
 
-async function loadClaudeTokenizer(modelPath) {
-    try {
-        const arrayBuffer = fs.readFileSync(modelPath).buffer;
-        const instance = await Tokenizer.fromJSON(arrayBuffer);
-        return instance;
-    } catch (error) {
-        console.error('Claude tokenizer failed to load: ' + modelPath, error);
-        return null;
-    }
-}
-
+/**
+ * Counts the tokens for the given messages using the Claude tokenizer.
+ * @param {Tokenizer} tokenizer Web tokenizer
+ * @param {object[]} messages Array of messages
+ * @returns {number} Number of tokens
+ */
 function countClaudeTokens(tokenizer, messages) {
     // Should be fine if we use the old conversion method instead of the messages API one i think?
     const convertedPrompt = convertClaudePrompt(messages, false, '', false, false, '', false);
@@ -264,9 +334,14 @@
 /**
  * Creates an API handler for encoding Sentencepiece tokens.
  * @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
- * @returns {any} Handler function
+ * @returns {TokenizationHandler} Handler function
  */
 function createSentencepieceEncodingHandler(tokenizer) {
+    /**
+     * Request handler for encoding Sentencepiece tokens.
+     * @param {import('express').Request} request
+     * @param {import('express').Response} response
+     */
     return async function (request, response) {
         try {
             if (!request.body) {
@@ -276,7 +351,7 @@
             const text = request.body.text || '';
             const instance = await tokenizer?.get();
             const { ids, count } = await countSentencepieceTokens(tokenizer, text);
-            const chunks = await instance?.encodePieces(text);
+            const chunks = instance?.encodePieces(text);
             return response.send({ ids, count, chunks });
         } catch (error) {
             console.log(error);
@@ -288,9 +363,14 @@
 /**
  * Creates an API handler for decoding Sentencepiece tokens.
  * @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
- * @returns {any} Handler function
+ * @returns {TokenizationHandler} Handler function
  */
 function createSentencepieceDecodingHandler(tokenizer) {
+    /**
+     * Request handler for decoding Sentencepiece tokens.
+     * @param {import('express').Request} request
+     * @param {import('express').Response} response
+     */
     return async function (request, response) {
         try {
             if (!request.body) {
@@ -299,6 +379,7 @@
 
             const ids = request.body.ids || [];
             const instance = await tokenizer?.get();
+            if (!instance) throw new Error('Failed to load the Sentencepiece tokenizer');
             const ops = ids.map(id => instance.decodeIds([id]));
             const chunks = await Promise.all(ops);
             const text = chunks.join('');
@@ -313,9 +394,14 @@
 /**
  * Creates an API handler for encoding Tiktoken tokens.
  * @param {string} modelId Tiktoken model ID
- * @returns {any} Handler function
+ * @returns {TokenizationHandler} Handler function
  */
 function createTiktokenEncodingHandler(modelId) {
+    /**
+     * Request handler for encoding Tiktoken tokens.
+     * @param {import('express').Request} request
+     * @param {import('express').Response} response
+     */
     return async function (request, response) {
         try {
             if (!request.body) {
@@ -337,9 +423,14 @@
 /**
  * Creates an API handler for decoding Tiktoken tokens.
  * @param {string} modelId Tiktoken model ID
- * @returns {any} Handler function
+ * @returns {TokenizationHandler} Handler function
  */
 function createTiktokenDecodingHandler(modelId) {
+    /**
+     * Request handler for decoding Tiktoken tokens.
+     * @param {import('express').Request} request
+     * @param {import('express').Response} response
+     */
     return async function (request, response) {
         try {
             if (!request.body) {
@@ -358,14 +449,6 @@
     };
 }
 
-/**
- * Loads the model tokenizers.
- * @returns {Promise<void>} Promise that resolves when the tokenizers are loaded
- */
-async function loadTokenizers() {
-    claude_tokenizer = await loadClaudeTokenizer('src/claude.json');
-}
-
 const router = express.Router();
 
 router.post('/ai21/count', jsonParser, async function (req, res) {
@@ -446,8 +529,10 @@ router.post('/openai/encode', jsonParser, async function (req, res) {
 
         if (queryModel.includes('claude')) {
             const text = req.body.text || '';
-            const tokens = Object.values(claude_tokenizer.encode(text));
-            const chunks = await getWebTokenizersChunks(claude_tokenizer, tokens);
+            const instance = await claude_tokenizer.get();
+            if (!instance) throw new Error('Failed to load the Claude tokenizer');
+            const tokens = Object.values(instance.encode(text));
+            const chunks = getWebTokenizersChunks(instance, tokens);
             return res.send({ ids: tokens, count: tokens.length, chunks });
         }
 
@@ -481,7 +566,9 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
 
         if (queryModel.includes('claude')) {
             const ids = req.body.ids || [];
-            const chunkText = await claude_tokenizer.decode(new Uint32Array(ids));
+            const instance = await claude_tokenizer.get();
+            if (!instance) throw new Error('Failed to load the Claude tokenizer');
+            const chunkText = instance.decode(new Int32Array(ids));
             return res.send({ text: chunkText });
         }
 
@@ -503,7 +590,9 @@
         const model = getTokenizerModel(queryModel);
 
         if (model === 'claude') {
-            num_tokens = countClaudeTokens(claude_tokenizer, req.body);
+            const instance = await claude_tokenizer.get();
+            if (!instance) throw new Error('Failed to load the Claude tokenizer');
+            num_tokens = countClaudeTokens(instance, req.body);
             return res.send({ 'token_count': num_tokens });
         }
 
@@ -665,7 +754,6 @@ module.exports = {
     getTokenizerModel,
     getTiktokenTokenizer,
     countClaudeTokens,
-    loadTokenizers,
     getSentencepiceTokenizer,
     sentencepieceTokenizers,
     router,
diff --git a/src/claude.json b/src/tokenizers/claude.json
similarity index 100%
rename from src/claude.json
rename to src/tokenizers/claude.json
diff --git a/src/sentencepiece/llama.model b/src/tokenizers/llama.model
similarity index 100%
rename from src/sentencepiece/llama.model
rename to src/tokenizers/llama.model
diff --git a/src/sentencepiece/mistral.model b/src/tokenizers/mistral.model
similarity index 100%
rename from src/sentencepiece/mistral.model
rename to src/tokenizers/mistral.model
diff --git a/src/sentencepiece/nerdstash.model b/src/tokenizers/nerdstash.model
similarity index 100%
rename from src/sentencepiece/nerdstash.model
rename to src/tokenizers/nerdstash.model
diff --git a/src/sentencepiece/nerdstash_v2.model b/src/tokenizers/nerdstash_v2.model
similarity index 100%
rename from src/sentencepiece/nerdstash_v2.model
rename to src/tokenizers/nerdstash_v2.model
diff --git a/src/sentencepiece/yi.model b/src/tokenizers/yi.model
similarity index 100%
rename from src/sentencepiece/yi.model
rename to src/tokenizers/yi.model