const fs = require('fs');
const { SentencePieceProcessor } = require("@agnai/sentencepiece-js");
const tiktoken = require('@dqbd/tiktoken');
const { Tokenizer } = require('@agnai/web-tokenizers');
const { convertClaudePrompt } = require('./chat-completion');
const { readSecret, SECRET_KEYS } = require('./secrets');

/**
 * @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
 */
const tokenizersCache = {};

/**
 * @type {string[]}
 */
const TEXT_COMPLETION_MODELS = [
    "gpt-3.5-turbo-instruct",
    "gpt-3.5-turbo-instruct-0914",
    "text-davinci-003",
    "text-davinci-002",
    "text-davinci-001",
    "text-curie-001",
    "text-babbage-001",
    "text-ada-001",
    "code-davinci-002",
    "code-davinci-001",
    "code-cushman-002",
    "code-cushman-001",
    "text-davinci-edit-001",
    "code-davinci-edit-001",
    "text-embedding-ada-002",
    "text-similarity-davinci-001",
    "text-similarity-curie-001",
    "text-similarity-babbage-001",
    "text-similarity-ada-001",
    "text-search-davinci-doc-001",
    "text-search-curie-doc-001",
    "text-search-babbage-doc-001",
    "text-search-ada-doc-001",
    "code-search-babbage-code-001",
    "code-search-ada-code-001",
];

// Average characters per token, used for length-based estimation when a tokenizer is unavailable
const CHARS_PER_TOKEN = 3.35;

let spp_llama;
let spp_nerd;
let spp_nerd_v2;
let spp_mistral;
let claude_tokenizer;

async function loadSentencepieceTokenizer(modelPath) {
    try {
        const spp = new SentencePieceProcessor();
        await spp.load(modelPath);
        return spp;
    } catch (error) {
        console.error("Sentencepiece tokenizer failed to load: " + modelPath, error);
        return null;
    }
}

const sentencepieceTokenizers = [
    'llama',
    'nerdstash',
    'nerdstash_v2',
    'mistral',
];

/**
 * Gets the Sentencepiece tokenizer by the model name.
 * @param {string} model Sentencepiece model name
 * @returns {*} Sentencepiece tokenizer
 */
function getSentencepiceTokenizer(model) {
    if (model.includes('llama')) {
        return spp_llama;
    }

    // Check the more specific 'nerdstash_v2' before 'nerdstash', otherwise this branch is unreachable
    if (model.includes('nerdstash_v2')) {
        return spp_nerd_v2;
    }

    if (model.includes('nerdstash')) {
        return spp_nerd;
    }

    if (model.includes('mistral')) {
        return spp_mistral;
    }
}

async function countSentencepieceTokens(spp, text) {
    // Fallback to strlen estimation
    if (!spp) {
        return {
            ids: [],
            count: Math.ceil(text.length / CHARS_PER_TOKEN),
        };
    }

    let cleaned = text; // cleanText(text); <-- cleaning text can result in an incorrect tokenization

    let ids = spp.encodeIds(cleaned);
    return {
        ids,
        count: ids.length,
    };
}

async function countSentencepieceArrayTokens(tokenizer, array) {
    const jsonBody = array.flatMap(x => Object.values(x)).join('\n\n');
    const result = await countSentencepieceTokens(tokenizer, jsonBody);
    const num_tokens = result.count;
    return num_tokens;
}

async function getTiktokenChunks(tokenizer, ids) {
    const decoder = new TextDecoder();
    const chunks = [];

    for (let i = 0; i < ids.length; i++) {
        const id = ids[i];
        const chunkTextBytes = await tokenizer.decode(new Uint32Array([id]));
        const chunkText = decoder.decode(chunkTextBytes);
        chunks.push(chunkText);
    }

    return chunks;
}

async function getWebTokenizersChunks(tokenizer, ids) {
    const chunks = [];

    for (let i = 0; i < ids.length; i++) {
        const id = ids[i];
        const chunkText = await tokenizer.decode(new Uint32Array([id]));
        chunks.push(chunkText);
    }

    return chunks;
}
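// Usage sketch for the counting helper above (illustrative; assumes loadTokenizers()
// has already resolved, so spp_llama is populated — the id values are placeholders):
//
//   const { ids, count } = await countSentencepieceTokens(spp_llama, 'Hello world');
//   // -> { ids: [...], count: ids.length }
//   // With a tokenizer that failed to load (null), the count falls back to
//   // Math.ceil(text.length / CHARS_PER_TOKEN), i.e. Math.ceil(11 / 3.35) === 4 here.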
/**
 * Gets the tokenizer model by the model name.
 * @param {string} requestModel Model name to use for tokenization
 * @returns {string} Tokenizer model to use
 */
function getTokenizerModel(requestModel) {
    if (requestModel.includes('claude')) {
        return 'claude';
    }

    if (requestModel.includes('llama')) {
        return 'llama';
    }

    if (requestModel.includes('mistral')) {
        return 'mistral';
    }

    if (requestModel.includes('gpt-4-32k')) {
        return 'gpt-4-32k';
    }

    if (requestModel.includes('gpt-4')) {
        return 'gpt-4';
    }

    if (requestModel.includes('gpt-3.5-turbo-0301')) {
        return 'gpt-3.5-turbo-0301';
    }

    if (requestModel.includes('gpt-3.5-turbo')) {
        return 'gpt-3.5-turbo';
    }

    if (TEXT_COMPLETION_MODELS.includes(requestModel)) {
        return requestModel;
    }

    // default
    return 'gpt-3.5-turbo';
}

function getTiktokenTokenizer(model) {
    if (tokenizersCache[model]) {
        return tokenizersCache[model];
    }

    const tokenizer = tiktoken.encoding_for_model(model);
    console.log('Instantiated the tokenizer for', model);
    tokenizersCache[model] = tokenizer;
    return tokenizer;
}

async function loadClaudeTokenizer(modelPath) {
    try {
        const arrayBuffer = fs.readFileSync(modelPath).buffer;
        const instance = await Tokenizer.fromJSON(arrayBuffer);
        return instance;
    } catch (error) {
        console.error("Claude tokenizer failed to load: " + modelPath, error);
        return null;
    }
}

function countClaudeTokens(tokenizer, messages) {
    const convertedPrompt = convertClaudePrompt(messages, false, false);

    // Fallback to strlen estimation
    if (!tokenizer) {
        return Math.ceil(convertedPrompt.length / CHARS_PER_TOKEN);
    }

    const count = tokenizer.encode(convertedPrompt).length;
    return count;
}

/**
 * Creates an API handler for encoding Sentencepiece tokens.
 * @param {function} getTokenizerFn Tokenizer provider function
 * @returns {any} Handler function
 */
function createSentencepieceEncodingHandler(getTokenizerFn) {
    return async function (request, response) {
        try {
            if (!request.body) {
                return response.sendStatus(400);
            }

            const text = request.body.text || '';
            const tokenizer = getTokenizerFn();
            const { ids, count } = await countSentencepieceTokens(tokenizer, text);
            // Guard against a tokenizer that failed to load, keeping the estimated count
            const chunks = tokenizer ? await tokenizer.encodePieces(text) : [];
            return response.send({ ids, count, chunks });
        } catch (error) {
            console.log(error);
            return response.send({ ids: [], count: 0, chunks: [] });
        }
    };
}

/**
 * Creates an API handler for decoding Sentencepiece tokens.
 * @param {function} getTokenizerFn Tokenizer provider function
 * @returns {any} Handler function
 */
function createSentencepieceDecodingHandler(getTokenizerFn) {
    return async function (request, response) {
        try {
            if (!request.body) {
                return response.sendStatus(400);
            }

            const ids = request.body.ids || [];
            const tokenizer = getTokenizerFn();

            // Guard against a tokenizer that failed to load
            if (!tokenizer) {
                return response.send({ text: '' });
            }

            const text = await tokenizer.decodeIds(ids);
            return response.send({ text });
        } catch (error) {
            console.log(error);
            return response.send({ text: '' });
        }
    };
}

/**
 * Creates an API handler for encoding Tiktoken tokens.
 * @param {string} modelId Tiktoken model ID
 * @returns {any} Handler function
 */
function createTiktokenEncodingHandler(modelId) {
    return async function (request, response) {
        try {
            if (!request.body) {
                return response.sendStatus(400);
            }

            const text = request.body.text || '';
            const tokenizer = getTiktokenTokenizer(modelId);
            const tokens = Object.values(tokenizer.encode(text));
            const chunks = await getTiktokenChunks(tokenizer, tokens);
            return response.send({ ids: tokens, count: tokens.length, chunks });
        } catch (error) {
            console.log(error);
            return response.send({ ids: [], count: 0, chunks: [] });
        }
    };
}
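// Request/response shape handled by createTiktokenEncodingHandler (illustrative;
// the actual id values depend on the model's vocabulary):
//
//   POST /api/tokenize/gpt2   body: { "text": "Hello world" }
//   -> { "ids": [<token ids>], "count": <ids.length>, "chunks": ["Hello", " world"] }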
/**
 * Creates an API handler for decoding Tiktoken tokens.
 * @param {string} modelId Tiktoken model ID
 * @returns {any} Handler function
 */
function createTiktokenDecodingHandler(modelId) {
    return async function (request, response) {
        try {
            if (!request.body) {
                return response.sendStatus(400);
            }

            const ids = request.body.ids || [];
            const tokenizer = getTiktokenTokenizer(modelId);
            const textBytes = tokenizer.decode(new Uint32Array(ids));
            const text = new TextDecoder().decode(textBytes);
            return response.send({ text });
        } catch (error) {
            console.log(error);
            return response.send({ text: '' });
        }
    };
}

/**
 * Loads the model tokenizers.
 * @returns {Promise} Promise that resolves when the tokenizers are loaded
 */
async function loadTokenizers() {
    [spp_llama, spp_nerd, spp_nerd_v2, spp_mistral, claude_tokenizer] = await Promise.all([
        loadSentencepieceTokenizer('src/sentencepiece/llama.model'),
        loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
        loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
        loadSentencepieceTokenizer('src/sentencepiece/mistral.model'),
        loadClaudeTokenizer('src/claude.json'),
    ]);
}

/**
 * Registers the tokenization endpoints.
 * @param {import('express').Express} app Express app
 * @param {any} jsonParser JSON parser middleware
 */
function registerEndpoints(app, jsonParser) {
    app.post("/api/tokenize/ai21", jsonParser, async function (req, res) {
        if (!req.body) return res.sendStatus(400);

        const options = {
            method: 'POST',
            headers: {
                accept: 'application/json',
                'content-type': 'application/json',
                Authorization: `Bearer ${readSecret(SECRET_KEYS.AI21)}`,
            },
            body: JSON.stringify({ text: req.body[0].content }),
        };

        try {
            const response = await fetch('https://api.ai21.com/studio/v1/tokenize', options);
            const data = await response.json();
            return res.send({ "token_count": data?.tokens?.length || 0 });
        } catch (err) {
            console.error(err);
            return res.send({ "token_count": 0 });
        }
    });

    app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(() => spp_llama));
    app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd));
    app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd_v2));
    app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(() => spp_mistral));
    app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
    app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(() => spp_llama));
    app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd));
    app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2));
    app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(() => spp_mistral));
    app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
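    // Decode round-trip sketch for the routes above (illustrative; a real request
    // would pass ids produced by the matching /api/tokenize/* endpoint):
    //
    //   POST /api/decode/llama   body: { "ids": [...] }
    //   -> { "text": "<the decoded string>" }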
    app.post("/api/tokenize/openai-encode", jsonParser, async function (req, res) {
        try {
            const queryModel = String(req.query.model || '');

            if (queryModel.includes('llama')) {
                const handler = createSentencepieceEncodingHandler(() => spp_llama);
                return handler(req, res);
            }

            if (queryModel.includes('mistral')) {
                const handler = createSentencepieceEncodingHandler(() => spp_mistral);
                return handler(req, res);
            }

            if (queryModel.includes('claude')) {
                const text = req.body.text || '';
                const tokens = Object.values(claude_tokenizer.encode(text));
                const chunks = await getWebTokenizersChunks(claude_tokenizer, tokens);
                return res.send({ ids: tokens, count: tokens.length, chunks });
            }

            const model = getTokenizerModel(queryModel);
            const handler = createTiktokenEncodingHandler(model);
            return handler(req, res);
        } catch (error) {
            console.log(error);
            return res.send({ ids: [], count: 0, chunks: [] });
        }
    });

    app.post("/api/tokenize/openai", jsonParser, async function (req, res) {
        try {
            if (!req.body) return res.sendStatus(400);

            let num_tokens = 0;
            const queryModel = String(req.query.model || '');
            const model = getTokenizerModel(queryModel);

            if (model == 'claude') {
                num_tokens = countClaudeTokens(claude_tokenizer, req.body);
                return res.send({ "token_count": num_tokens });
            }

            if (model == 'llama') {
                num_tokens = await countSentencepieceArrayTokens(spp_llama, req.body);
                return res.send({ "token_count": num_tokens });
            }

            if (model == 'mistral') {
                num_tokens = await countSentencepieceArrayTokens(spp_mistral, req.body);
                return res.send({ "token_count": num_tokens });
            }

            const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
            const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
            const tokensPadding = 3;

            const tokenizer = getTiktokenTokenizer(model);

            for (const msg of req.body) {
                try {
                    num_tokens += tokensPerMessage;
                    for (const [key, value] of Object.entries(msg)) {
                        num_tokens += tokenizer.encode(value).length;
                        if (key == "name") {
                            num_tokens += tokensPerName;
                        }
                    }
                } catch {
                    console.warn("Error tokenizing message:", msg);
                }
            }
            num_tokens += tokensPadding;

            // NB: Since 2023-10-14, the GPT-3.5 Turbo 0301 model shoves in 7-9 extra tokens to every message.
            // More details: https://community.openai.com/t/gpt-3-5-turbo-0301-showing-different-behavior-suddenly/431326/14
            if (queryModel.includes('gpt-3.5-turbo-0301')) {
                num_tokens += 9;
            }

            // not needed for cached tokenizers
            //tokenizer.free();

            res.send({ "token_count": num_tokens });
        } catch (error) {
            console.error('An error occurred while counting tokens, using fallback estimation method', error);
            const jsonBody = JSON.stringify(req.body);
            const num_tokens = Math.ceil(jsonBody.length / CHARS_PER_TOKEN);
            res.send({ "token_count": num_tokens });
        }
    });
}

module.exports = {
    TEXT_COMPLETION_MODELS,
    getTokenizerModel,
    getTiktokenTokenizer,
    loadSentencepieceTokenizer,
    loadClaudeTokenizer,
    countSentencepieceTokens,
    countClaudeTokens,
    loadTokenizers,
    registerEndpoints,
    getSentencepiceTokenizer,
    sentencepieceTokenizers,
};
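// Minimal wiring sketch for this module (assumption: an Express host app; the
// module path and port are illustrative, only the two exports used are real):
//
//   const express = require('express');
//   const { loadTokenizers, registerEndpoints } = require('./tokenizers');
//
//   const app = express();
//   loadTokenizers().then(() => {
//       registerEndpoints(app, express.json());
//       app.listen(3000);
//   });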