Add lazy loading of sentencepiece tokenizers
This commit is contained in:
parent
9199750afe
commit
3fb26d3927
|
@ -2795,13 +2795,13 @@ app.post("/openai_bias", jsonParser, async function (request, response) {
|
||||||
|
|
||||||
if (sentencepieceTokenizers.includes(model)) {
|
if (sentencepieceTokenizers.includes(model)) {
|
||||||
const tokenizer = getSentencepiceTokenizer(model);
|
const tokenizer = getSentencepiceTokenizer(model);
|
||||||
encodeFunction = (text) => new Uint32Array(tokenizer.encodeIds(text));
|
const instance = await tokenizer?.get();
|
||||||
|
encodeFunction = (text) => new Uint32Array(instance?.encodeIds(text));
|
||||||
} else {
|
} else {
|
||||||
const tokenizer = getTiktokenTokenizer(model);
|
const tokenizer = getTiktokenTokenizer(model);
|
||||||
encodeFunction = (tokenizer.encode.bind(tokenizer));
|
encodeFunction = (tokenizer.encode.bind(tokenizer));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
for (const entry of request.body) {
|
for (const entry of request.body) {
|
||||||
if (!entry || !entry.text) {
|
if (!entry || !entry.text) {
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
const fs = require('fs');
|
const fs = require('fs');
|
||||||
|
const path = require('path');
|
||||||
const { SentencePieceProcessor } = require("@agnai/sentencepiece-js");
|
const { SentencePieceProcessor } = require("@agnai/sentencepiece-js");
|
||||||
const tiktoken = require('@dqbd/tiktoken');
|
const tiktoken = require('@dqbd/tiktoken');
|
||||||
const { Tokenizer } = require('@agnai/web-tokenizers');
|
const { Tokenizer } = require('@agnai/web-tokenizers');
|
||||||
|
@ -43,22 +44,39 @@ const TEXT_COMPLETION_MODELS = [
|
||||||
|
|
||||||
const CHARS_PER_TOKEN = 3.35;
|
const CHARS_PER_TOKEN = 3.35;
|
||||||
|
|
||||||
let spp_llama;
|
class SentencePieceTokenizer {
|
||||||
let spp_nerd;
|
#instance;
|
||||||
let spp_nerd_v2;
|
#model;
|
||||||
let spp_mistral;
|
|
||||||
let claude_tokenizer;
|
constructor(model) {
|
||||||
|
this.#model = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the Sentencepiece tokenizer instance.
|
||||||
|
*/
|
||||||
|
async get() {
|
||||||
|
if (this.#instance) {
|
||||||
|
return this.#instance;
|
||||||
|
}
|
||||||
|
|
||||||
async function loadSentencepieceTokenizer(modelPath) {
|
|
||||||
try {
|
try {
|
||||||
const spp = new SentencePieceProcessor();
|
this.#instance = new SentencePieceProcessor();
|
||||||
await spp.load(modelPath);
|
await this.#instance.load(this.#model);
|
||||||
return spp;
|
console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
|
||||||
|
return this.#instance;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Sentencepiece tokenizer failed to load: " + modelPath, error);
|
console.error("Sentencepiece tokenizer failed to load: " + this.#model, error);
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const spp_llama = new SentencePieceTokenizer('src/sentencepiece/llama.model');
|
||||||
|
const spp_nerd = new SentencePieceTokenizer('src/sentencepiece/nerdstash.model');
|
||||||
|
const spp_nerd_v2 = new SentencePieceTokenizer('src/sentencepiece/nerdstash_v2.model');
|
||||||
|
const spp_mistral = new SentencePieceTokenizer('src/sentencepiece/mistral.model');
|
||||||
|
let claude_tokenizer;
|
||||||
|
|
||||||
const sentencepieceTokenizers = [
|
const sentencepieceTokenizers = [
|
||||||
'llama',
|
'llama',
|
||||||
|
@ -70,7 +88,7 @@ const sentencepieceTokenizers = [
|
||||||
/**
|
/**
|
||||||
* Gets the Sentencepiece tokenizer by the model name.
|
* Gets the Sentencepiece tokenizer by the model name.
|
||||||
* @param {string} model Sentencepiece model name
|
* @param {string} model Sentencepiece model name
|
||||||
* @returns {*} Sentencepiece tokenizer
|
* @returns {SentencePieceTokenizer|null} Sentencepiece tokenizer
|
||||||
*/
|
*/
|
||||||
function getSentencepiceTokenizer(model) {
|
function getSentencepiceTokenizer(model) {
|
||||||
if (model.includes('llama')) {
|
if (model.includes('llama')) {
|
||||||
|
@ -88,11 +106,21 @@ function getSentencepiceTokenizer(model) {
|
||||||
if (model.includes('nerdstash_v2')) {
|
if (model.includes('nerdstash_v2')) {
|
||||||
return spp_nerd_v2;
|
return spp_nerd_v2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function countSentencepieceTokens(spp, text) {
|
/**
|
||||||
|
* Counts the token ids for the given text using the Sentencepiece tokenizer.
|
||||||
|
* @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
|
||||||
|
* @param {string} text Text to tokenize
|
||||||
|
* @returns { Promise<{ids: number[], count: number}> } Tokenization result
|
||||||
|
*/
|
||||||
|
async function countSentencepieceTokens(tokenizer, text) {
|
||||||
|
const instance = await tokenizer?.get();
|
||||||
|
|
||||||
// Fallback to strlen estimation
|
// Fallback to strlen estimation
|
||||||
if (!spp) {
|
if (!instance) {
|
||||||
return {
|
return {
|
||||||
ids: [],
|
ids: [],
|
||||||
count: Math.ceil(text.length / CHARS_PER_TOKEN)
|
count: Math.ceil(text.length / CHARS_PER_TOKEN)
|
||||||
|
@ -101,13 +129,19 @@ async function countSentencepieceTokens(spp, text) {
|
||||||
|
|
||||||
let cleaned = text; // cleanText(text); <-- cleaning text can result in an incorrect tokenization
|
let cleaned = text; // cleanText(text); <-- cleaning text can result in an incorrect tokenization
|
||||||
|
|
||||||
let ids = spp.encodeIds(cleaned);
|
let ids = instance.encodeIds(cleaned);
|
||||||
return {
|
return {
|
||||||
ids,
|
ids,
|
||||||
count: ids.length
|
count: ids.length
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Counts the tokens in the given array of objects using the Sentencepiece tokenizer.
|
||||||
|
* @param {SentencePieceTokenizer} tokenizer
|
||||||
|
* @param {object[]} array Array of objects to tokenize
|
||||||
|
* @returns {Promise<number>} Number of tokens
|
||||||
|
*/
|
||||||
async function countSentencepieceArrayTokens(tokenizer, array) {
|
async function countSentencepieceArrayTokens(tokenizer, array) {
|
||||||
const jsonBody = array.flatMap(x => Object.values(x)).join('\n\n');
|
const jsonBody = array.flatMap(x => Object.values(x)).join('\n\n');
|
||||||
const result = await countSentencepieceTokens(tokenizer, jsonBody);
|
const result = await countSentencepieceTokens(tokenizer, jsonBody);
|
||||||
|
@ -219,10 +253,10 @@ function countClaudeTokens(tokenizer, messages) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an API handler for encoding Sentencepiece tokens.
|
* Creates an API handler for encoding Sentencepiece tokens.
|
||||||
* @param {function} getTokenizerFn Tokenizer provider function
|
* @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
|
||||||
* @returns {any} Handler function
|
* @returns {any} Handler function
|
||||||
*/
|
*/
|
||||||
function createSentencepieceEncodingHandler(getTokenizerFn) {
|
function createSentencepieceEncodingHandler(tokenizer) {
|
||||||
return async function (request, response) {
|
return async function (request, response) {
|
||||||
try {
|
try {
|
||||||
if (!request.body) {
|
if (!request.body) {
|
||||||
|
@ -230,9 +264,9 @@ function createSentencepieceEncodingHandler(getTokenizerFn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const text = request.body.text || '';
|
const text = request.body.text || '';
|
||||||
const tokenizer = getTokenizerFn();
|
const instance = await tokenizer?.get();
|
||||||
const { ids, count } = await countSentencepieceTokens(tokenizer, text);
|
const { ids, count } = await countSentencepieceTokens(tokenizer, text);
|
||||||
const chunks = await tokenizer.encodePieces(text);
|
const chunks = await instance?.encodePieces(text);
|
||||||
return response.send({ ids, count, chunks });
|
return response.send({ ids, count, chunks });
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log(error);
|
console.log(error);
|
||||||
|
@ -243,10 +277,10 @@ function createSentencepieceEncodingHandler(getTokenizerFn) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an API handler for decoding Sentencepiece tokens.
|
* Creates an API handler for decoding Sentencepiece tokens.
|
||||||
* @param {function} getTokenizerFn Tokenizer provider function
|
* @param {SentencePieceTokenizer} tokenizer Sentencepiece tokenizer
|
||||||
* @returns {any} Handler function
|
* @returns {any} Handler function
|
||||||
*/
|
*/
|
||||||
function createSentencepieceDecodingHandler(getTokenizerFn) {
|
function createSentencepieceDecodingHandler(tokenizer) {
|
||||||
return async function (request, response) {
|
return async function (request, response) {
|
||||||
try {
|
try {
|
||||||
if (!request.body) {
|
if (!request.body) {
|
||||||
|
@ -254,8 +288,8 @@ function createSentencepieceDecodingHandler(getTokenizerFn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const ids = request.body.ids || [];
|
const ids = request.body.ids || [];
|
||||||
const tokenizer = getTokenizerFn();
|
const instance = await tokenizer?.get();
|
||||||
const text = await tokenizer.decodeIds(ids);
|
const text = await instance?.decodeIds(ids);
|
||||||
return response.send({ text });
|
return response.send({ text });
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log(error);
|
console.log(error);
|
||||||
|
@ -317,13 +351,7 @@ function createTiktokenDecodingHandler(modelId) {
|
||||||
* @returns {Promise<void>} Promise that resolves when the tokenizers are loaded
|
* @returns {Promise<void>} Promise that resolves when the tokenizers are loaded
|
||||||
*/
|
*/
|
||||||
async function loadTokenizers() {
|
async function loadTokenizers() {
|
||||||
[spp_llama, spp_nerd, spp_nerd_v2, spp_mistral, claude_tokenizer] = await Promise.all([
|
claude_tokenizer = await loadClaudeTokenizer('src/claude.json');
|
||||||
loadSentencepieceTokenizer('src/sentencepiece/llama.model'),
|
|
||||||
loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
|
|
||||||
loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
|
|
||||||
loadSentencepieceTokenizer('src/sentencepiece/mistral.model'),
|
|
||||||
loadClaudeTokenizer('src/claude.json'),
|
|
||||||
]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -354,15 +382,15 @@ function registerEndpoints(app, jsonParser) {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(() => spp_llama));
|
app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(spp_llama));
|
||||||
app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd));
|
app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(spp_nerd));
|
||||||
app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd_v2));
|
app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
|
||||||
app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(() => spp_mistral));
|
app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(spp_mistral));
|
||||||
app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
|
app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
|
||||||
app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(() => spp_llama));
|
app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(spp_llama));
|
||||||
app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd));
|
app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(spp_nerd));
|
||||||
app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2));
|
app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
|
||||||
app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(() => spp_mistral));
|
app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(spp_mistral));
|
||||||
app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
|
app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
|
||||||
|
|
||||||
app.post("/api/tokenize/openai-encode", jsonParser, async function (req, res) {
|
app.post("/api/tokenize/openai-encode", jsonParser, async function (req, res) {
|
||||||
|
@ -370,12 +398,12 @@ function registerEndpoints(app, jsonParser) {
|
||||||
const queryModel = String(req.query.model || '');
|
const queryModel = String(req.query.model || '');
|
||||||
|
|
||||||
if (queryModel.includes('llama')) {
|
if (queryModel.includes('llama')) {
|
||||||
const handler = createSentencepieceEncodingHandler(() => spp_llama);
|
const handler = createSentencepieceEncodingHandler(spp_llama);
|
||||||
return handler(req, res);
|
return handler(req, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (queryModel.includes('mistral')) {
|
if (queryModel.includes('mistral')) {
|
||||||
const handler = createSentencepieceEncodingHandler(() => spp_mistral);
|
const handler = createSentencepieceEncodingHandler(spp_mistral);
|
||||||
return handler(req, res);
|
return handler(req, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -462,9 +490,6 @@ module.exports = {
|
||||||
TEXT_COMPLETION_MODELS,
|
TEXT_COMPLETION_MODELS,
|
||||||
getTokenizerModel,
|
getTokenizerModel,
|
||||||
getTiktokenTokenizer,
|
getTiktokenTokenizer,
|
||||||
loadSentencepieceTokenizer,
|
|
||||||
loadClaudeTokenizer,
|
|
||||||
countSentencepieceTokens,
|
|
||||||
countClaudeTokens,
|
countClaudeTokens,
|
||||||
loadTokenizers,
|
loadTokenizers,
|
||||||
registerEndpoints,
|
registerEndpoints,
|
||||||
|
|
Loading…
Reference in New Issue