Lazy initialization of Claude tokenizer. Add JSDoc for tokenizer handlers
This commit is contained in:
parent
1b60e4a013
commit
212e61d2a1
|
@ -45,7 +45,6 @@ const {
|
||||||
forwardFetchResponse,
|
forwardFetchResponse,
|
||||||
} = require('./src/util');
|
} = require('./src/util');
|
||||||
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
|
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
|
||||||
const { loadTokenizers } = require('./src/endpoints/tokenizers');
|
|
||||||
|
|
||||||
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
|
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
|
||||||
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
|
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
|
||||||
|
@ -548,7 +547,6 @@ const setupTasks = async function () {
|
||||||
await ensureThumbnailCache();
|
await ensureThumbnailCache();
|
||||||
cleanUploads();
|
cleanUploads();
|
||||||
|
|
||||||
await loadTokenizers();
|
|
||||||
await settingsEndpoint.init();
|
await settingsEndpoint.init();
|
||||||
await statsEndpoint.init();
|
await statsEndpoint.init();
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,10 @@ const { TEXTGEN_TYPES } = require('../constants');
|
||||||
const { jsonParser } = require('../express-common');
|
const { jsonParser } = require('../express-common');
|
||||||
const { setAdditionalHeaders } = require('../additional-headers');
|
const { setAdditionalHeaders } = require('../additional-headers');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @typedef { (req: import('express').Request, res: import('express').Response) => Promise<any> } TokenizationHandler
|
||||||
|
*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
|
* @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
|
||||||
*/
|
*/
|
||||||
|
@ -48,16 +52,30 @@ const TEXT_COMPLETION_MODELS = [
|
||||||
|
|
||||||
const CHARS_PER_TOKEN = 3.35;
|
const CHARS_PER_TOKEN = 3.35;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sentencepiece tokenizer for tokenizing text.
|
||||||
|
*/
|
||||||
class SentencePieceTokenizer {
|
class SentencePieceTokenizer {
|
||||||
|
/**
|
||||||
|
* @type {import('@agnai/sentencepiece-js').SentencePieceProcessor} Sentencepiece tokenizer instance
|
||||||
|
*/
|
||||||
#instance;
|
#instance;
|
||||||
|
/**
|
||||||
|
* @type {string} Path to the tokenizer model
|
||||||
|
*/
|
||||||
#model;
|
#model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new Sentencepiece tokenizer.
|
||||||
|
* @param {string} model Path to the tokenizer model
|
||||||
|
*/
|
||||||
constructor(model) {
|
constructor(model) {
|
||||||
this.#model = model;
|
this.#model = model;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the Sentencepiece tokenizer instance.
|
* Gets the Sentencepiece tokenizer instance.
|
||||||
|
* @returns {Promise<import('@agnai/sentencepiece-js').SentencePieceProcessor|null>} Sentencepiece tokenizer instance
|
||||||
*/
|
*/
|
||||||
async get() {
|
async get() {
|
||||||
if (this.#instance) {
|
if (this.#instance) {
|
||||||
|
@ -76,18 +94,61 @@ class SentencePieceTokenizer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const spp_llama = new SentencePieceTokenizer('src/sentencepiece/llama.model');
|
/**
|
||||||
const spp_nerd = new SentencePieceTokenizer('src/sentencepiece/nerdstash.model');
|
* Web tokenizer for tokenizing text.
|
||||||
const spp_nerd_v2 = new SentencePieceTokenizer('src/sentencepiece/nerdstash_v2.model');
|
*/
|
||||||
const spp_mistral = new SentencePieceTokenizer('src/sentencepiece/mistral.model');
|
class WebTokenizer {
|
||||||
const spp_yi = new SentencePieceTokenizer('src/sentencepiece/yi.model');
|
/**
|
||||||
let claude_tokenizer;
|
* @type {Tokenizer} Web tokenizer instance
|
||||||
|
*/
|
||||||
|
#instance;
|
||||||
|
/**
|
||||||
|
* @type {string} Path to the tokenizer model
|
||||||
|
*/
|
||||||
|
#model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new Web tokenizer.
|
||||||
|
* @param {string} model Path to the tokenizer model
|
||||||
|
*/
|
||||||
|
constructor(model) {
|
||||||
|
this.#model = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the Web tokenizer instance.
|
||||||
|
* @returns {Promise<Tokenizer|null>} Web tokenizer instance
|
||||||
|
*/
|
||||||
|
async get() {
|
||||||
|
if (this.#instance) {
|
||||||
|
return this.#instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const arrayBuffer = fs.readFileSync(this.#model).buffer;
|
||||||
|
this.#instance = await Tokenizer.fromJSON(arrayBuffer);
|
||||||
|
console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
|
||||||
|
return this.#instance;
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Web tokenizer failed to load: ' + this.#model, error);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const spp_llama = new SentencePieceTokenizer('src/tokenizers/llama.model');
|
||||||
|
const spp_nerd = new SentencePieceTokenizer('src/tokenizers/nerdstash.model');
|
||||||
|
const spp_nerd_v2 = new SentencePieceTokenizer('src/tokenizers/nerdstash_v2.model');
|
||||||
|
const spp_mistral = new SentencePieceTokenizer('src/tokenizers/mistral.model');
|
||||||
|
const spp_yi = new SentencePieceTokenizer('src/tokenizers/yi.model');
|
||||||
|
const claude_tokenizer = new WebTokenizer('src/tokenizers/claude.json');
|
||||||
|
|
||||||
const sentencepieceTokenizers = [
|
const sentencepieceTokenizers = [
|
||||||
'llama',
|
'llama',
|
||||||
'nerdstash',
|
'nerdstash',
|
||||||
'nerdstash_v2',
|
'nerdstash_v2',
|
||||||
'mistral',
|
'mistral',
|
||||||
|
'yi',
|
||||||
];
|
];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -112,6 +173,10 @@ function getSentencepiceTokenizer(model) {
|
||||||
return spp_nerd_v2;
|
return spp_nerd_v2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (model.includes('yi')) {
|
||||||
|
return spp_yi;
|
||||||
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,13 +233,23 @@ async function getTiktokenChunks(tokenizer, ids) {
|
||||||
return chunks;
|
return chunks;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function getWebTokenizersChunks(tokenizer, ids) {
|
/**
|
||||||
|
* Gets the token chunks for the given token IDs using the Web tokenizer.
|
||||||
|
* @param {Tokenizer} tokenizer Web tokenizer instance
|
||||||
|
* @param {number[]} ids Token IDs
|
||||||
|
* @returns {string[]} Token chunks
|
||||||
|
*/
|
||||||
|
function getWebTokenizersChunks(tokenizer, ids) {
|
||||||
const chunks = [];
|
const chunks = [];
|
||||||
|
|
||||||
for (let i = 0; i < ids.length; i++) {
|
for (let i = 0, lastProcessed = 0; i < ids.length; i++) {
|
||||||
const id = ids[i];
|
const chunkIds = ids.slice(lastProcessed, i + 1);
|
||||||
const chunkText = await tokenizer.decode(new Uint32Array([id]));
|
const chunkText = tokenizer.decode(new Int32Array(chunkIds));
|
||||||
|
if (chunkText === '<27>') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
chunks.push(chunkText);
|
chunks.push(chunkText);
|
||||||
|
lastProcessed = i + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
return chunks;
|
return chunks;
|
||||||
|
@ -237,17 +312,12 @@ function getTiktokenTokenizer(model) {
|
||||||
return tokenizer;
|
return tokenizer;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function loadClaudeTokenizer(modelPath) {
|
/**
|
||||||
try {
|
* Counts the tokens for the given messages using the Claude tokenizer.
|
||||||
const arrayBuffer = fs.readFileSync(modelPath).buffer;
|
* @param {Tokenizer} tokenizer Web tokenizer
|
||||||
const instance = await Tokenizer.fromJSON(arrayBuffer);
|
* @param {object[]} messages Array of messages
|
||||||
return instance;
|
* @returns {number} Number of tokens
|
||||||
} catch (error) {
|
*/
|
||||||
console.error('Claude tokenizer failed to load: ' + modelPath, error);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function countClaudeTokens(tokenizer, messages) {
|
function countClaudeTokens(tokenizer, messages) {
|
||||||
// Should be fine if we use the old conversion method instead of the messages API one i think?
|
// Should be fine if we use the old conversion method instead of the messages API one i think?
|
||||||
const convertedPrompt = convertClaudePrompt(messages, false, '', false, false, '', false);
|
const convertedPrompt = convertClaudePrompt(messages, false, '', false, false, '', false);
|
||||||
|
@ -264,9 +334,14 @@ function countClaudeTokens(tokenizer, messages) {
|
||||||
/**
|
/**
|
||||||
* Creates an API handler for encoding Sentencepiece tokens.
|
* Creates an API handler for encoding Sentencepiece tokens.
|
||||||
* @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
|
* @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
|
||||||
* @returns {any} Handler function
|
* @returns {TokenizationHandler} Handler function
|
||||||
*/
|
*/
|
||||||
function createSentencepieceEncodingHandler(tokenizer) {
|
function createSentencepieceEncodingHandler(tokenizer) {
|
||||||
|
/**
|
||||||
|
* Request handler for encoding Sentencepiece tokens.
|
||||||
|
* @param {import('express').Request} request
|
||||||
|
* @param {import('express').Response} response
|
||||||
|
*/
|
||||||
return async function (request, response) {
|
return async function (request, response) {
|
||||||
try {
|
try {
|
||||||
if (!request.body) {
|
if (!request.body) {
|
||||||
|
@ -276,7 +351,7 @@ function createSentencepieceEncodingHandler(tokenizer) {
|
||||||
const text = request.body.text || '';
|
const text = request.body.text || '';
|
||||||
const instance = await tokenizer?.get();
|
const instance = await tokenizer?.get();
|
||||||
const { ids, count } = await countSentencepieceTokens(tokenizer, text);
|
const { ids, count } = await countSentencepieceTokens(tokenizer, text);
|
||||||
const chunks = await instance?.encodePieces(text);
|
const chunks = instance?.encodePieces(text);
|
||||||
return response.send({ ids, count, chunks });
|
return response.send({ ids, count, chunks });
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log(error);
|
console.log(error);
|
||||||
|
@ -288,9 +363,14 @@ function createSentencepieceEncodingHandler(tokenizer) {
|
||||||
/**
|
/**
|
||||||
* Creates an API handler for decoding Sentencepiece tokens.
|
* Creates an API handler for decoding Sentencepiece tokens.
|
||||||
* @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
|
* @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
|
||||||
* @returns {any} Handler function
|
* @returns {TokenizationHandler} Handler function
|
||||||
*/
|
*/
|
||||||
function createSentencepieceDecodingHandler(tokenizer) {
|
function createSentencepieceDecodingHandler(tokenizer) {
|
||||||
|
/**
|
||||||
|
* Request handler for decoding Sentencepiece tokens.
|
||||||
|
* @param {import('express').Request} request
|
||||||
|
* @param {import('express').Response} response
|
||||||
|
*/
|
||||||
return async function (request, response) {
|
return async function (request, response) {
|
||||||
try {
|
try {
|
||||||
if (!request.body) {
|
if (!request.body) {
|
||||||
|
@ -299,6 +379,7 @@ function createSentencepieceDecodingHandler(tokenizer) {
|
||||||
|
|
||||||
const ids = request.body.ids || [];
|
const ids = request.body.ids || [];
|
||||||
const instance = await tokenizer?.get();
|
const instance = await tokenizer?.get();
|
||||||
|
if (!instance) throw new Error('Failed to load the Sentencepiece tokenizer');
|
||||||
const ops = ids.map(id => instance.decodeIds([id]));
|
const ops = ids.map(id => instance.decodeIds([id]));
|
||||||
const chunks = await Promise.all(ops);
|
const chunks = await Promise.all(ops);
|
||||||
const text = chunks.join('');
|
const text = chunks.join('');
|
||||||
|
@ -313,9 +394,14 @@ function createSentencepieceDecodingHandler(tokenizer) {
|
||||||
/**
|
/**
|
||||||
* Creates an API handler for encoding Tiktoken tokens.
|
* Creates an API handler for encoding Tiktoken tokens.
|
||||||
* @param {string} modelId Tiktoken model ID
|
* @param {string} modelId Tiktoken model ID
|
||||||
* @returns {any} Handler function
|
* @returns {TokenizationHandler} Handler function
|
||||||
*/
|
*/
|
||||||
function createTiktokenEncodingHandler(modelId) {
|
function createTiktokenEncodingHandler(modelId) {
|
||||||
|
/**
|
||||||
|
* Request handler for encoding Tiktoken tokens.
|
||||||
|
* @param {import('express').Request} request
|
||||||
|
* @param {import('express').Response} response
|
||||||
|
*/
|
||||||
return async function (request, response) {
|
return async function (request, response) {
|
||||||
try {
|
try {
|
||||||
if (!request.body) {
|
if (!request.body) {
|
||||||
|
@ -337,9 +423,14 @@ function createTiktokenEncodingHandler(modelId) {
|
||||||
/**
|
/**
|
||||||
* Creates an API handler for decoding Tiktoken tokens.
|
* Creates an API handler for decoding Tiktoken tokens.
|
||||||
* @param {string} modelId Tiktoken model ID
|
* @param {string} modelId Tiktoken model ID
|
||||||
* @returns {any} Handler function
|
* @returns {TokenizationHandler} Handler function
|
||||||
*/
|
*/
|
||||||
function createTiktokenDecodingHandler(modelId) {
|
function createTiktokenDecodingHandler(modelId) {
|
||||||
|
/**
|
||||||
|
* Request handler for decoding Tiktoken tokens.
|
||||||
|
* @param {import('express').Request} request
|
||||||
|
* @param {import('express').Response} response
|
||||||
|
*/
|
||||||
return async function (request, response) {
|
return async function (request, response) {
|
||||||
try {
|
try {
|
||||||
if (!request.body) {
|
if (!request.body) {
|
||||||
|
@ -358,14 +449,6 @@ function createTiktokenDecodingHandler(modelId) {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Loads the model tokenizers.
|
|
||||||
* @returns {Promise<void>} Promise that resolves when the tokenizers are loaded
|
|
||||||
*/
|
|
||||||
async function loadTokenizers() {
|
|
||||||
claude_tokenizer = await loadClaudeTokenizer('src/claude.json');
|
|
||||||
}
|
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
router.post('/ai21/count', jsonParser, async function (req, res) {
|
router.post('/ai21/count', jsonParser, async function (req, res) {
|
||||||
|
@ -446,8 +529,10 @@ router.post('/openai/encode', jsonParser, async function (req, res) {
|
||||||
|
|
||||||
if (queryModel.includes('claude')) {
|
if (queryModel.includes('claude')) {
|
||||||
const text = req.body.text || '';
|
const text = req.body.text || '';
|
||||||
const tokens = Object.values(claude_tokenizer.encode(text));
|
const instance = await claude_tokenizer.get();
|
||||||
const chunks = await getWebTokenizersChunks(claude_tokenizer, tokens);
|
if (!instance) throw new Error('Failed to load the Claude tokenizer');
|
||||||
|
const tokens = Object.values(instance.encode(text));
|
||||||
|
const chunks = getWebTokenizersChunks(instance, tokens);
|
||||||
return res.send({ ids: tokens, count: tokens.length, chunks });
|
return res.send({ ids: tokens, count: tokens.length, chunks });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -481,7 +566,9 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
|
||||||
|
|
||||||
if (queryModel.includes('claude')) {
|
if (queryModel.includes('claude')) {
|
||||||
const ids = req.body.ids || [];
|
const ids = req.body.ids || [];
|
||||||
const chunkText = await claude_tokenizer.decode(new Uint32Array(ids));
|
const instance = await claude_tokenizer.get();
|
||||||
|
if (!instance) throw new Error('Failed to load the Claude tokenizer');
|
||||||
|
const chunkText = instance.decode(new Int32Array(ids));
|
||||||
return res.send({ text: chunkText });
|
return res.send({ text: chunkText });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -503,7 +590,9 @@ router.post('/openai/count', jsonParser, async function (req, res) {
|
||||||
const model = getTokenizerModel(queryModel);
|
const model = getTokenizerModel(queryModel);
|
||||||
|
|
||||||
if (model === 'claude') {
|
if (model === 'claude') {
|
||||||
num_tokens = countClaudeTokens(claude_tokenizer, req.body);
|
const instance = await claude_tokenizer.get();
|
||||||
|
if (!instance) throw new Error('Failed to load the Claude tokenizer');
|
||||||
|
num_tokens = countClaudeTokens(instance, req.body);
|
||||||
return res.send({ 'token_count': num_tokens });
|
return res.send({ 'token_count': num_tokens });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -665,7 +754,6 @@ module.exports = {
|
||||||
getTokenizerModel,
|
getTokenizerModel,
|
||||||
getTiktokenTokenizer,
|
getTiktokenTokenizer,
|
||||||
countClaudeTokens,
|
countClaudeTokens,
|
||||||
loadTokenizers,
|
|
||||||
getSentencepiceTokenizer,
|
getSentencepiceTokenizer,
|
||||||
sentencepieceTokenizers,
|
sentencepieceTokenizers,
|
||||||
router,
|
router,
|
||||||
|
|
Loading…
Reference in New Issue