Implement downloadable tokenizers

Closes #2574, #2754
This commit is contained in:
Cohee
2024-09-06 16:28:34 +00:00
parent 4a9401bfe2
commit 81251b073a
6 changed files with 187 additions and 10 deletions

View File

@ -4,13 +4,12 @@ const express = require('express');
const { SentencePieceProcessor } = require('@agnai/sentencepiece-js');
const tiktoken = require('tiktoken');
const { Tokenizer } = require('@agnai/web-tokenizers');
const { convertClaudePrompt, convertGooglePrompt } = require('../prompt-converters');
const { readSecret, SECRET_KEYS } = require('./secrets');
const { convertClaudePrompt } = require('../prompt-converters');
const { TEXTGEN_TYPES } = require('../constants');
const { jsonParser } = require('../express-common');
const { setAdditionalHeaders } = require('../additional-headers');
const API_MAKERSUITE = 'https://generativelanguage.googleapis.com';
const { getConfigValue, isValidUrl } = require('../util');
const writeFileAtomicSync = require('write-file-atomic').sync;
/**
* @typedef { (req: import('express').Request, res: import('express').Response) => Promise<any> } TokenizationHandler
@ -53,6 +52,65 @@ const TEXT_COMPLETION_MODELS = [
];
const CHARS_PER_TOKEN = 3.35;
const IS_DOWNLOAD_ALLOWED = getConfigValue('enableDownloadableTokenizers', true);
/**
* Gets a path to the tokenizer model. Downloads the model if it's a URL.
* @param {string} model Model URL or path
* @param {string|undefined} fallbackModel Fallback model path\
* @returns {Promise<string>} Path to the tokenizer model
*/
async function getPathToTokenizer(model, fallbackModel) {
if (!isValidUrl(model)) {
return model;
}
try {
const url = new URL(model);
if (!['https:', 'http:'].includes(url.protocol)) {
throw new Error('Invalid URL protocol');
}
const fileName = url.pathname.split('/').pop();
if (!fileName) {
throw new Error('Failed to extract the file name from the URL');
}
const CACHE_PATH = path.join(global.DATA_ROOT, '_cache');
if (!fs.existsSync(CACHE_PATH)) {
fs.mkdirSync(CACHE_PATH, { recursive: true });
}
const cachedFile = path.join(CACHE_PATH, fileName);
if (fs.existsSync(cachedFile)) {
return cachedFile;
}
if (!IS_DOWNLOAD_ALLOWED) {
throw new Error('Downloading tokenizers is disabled, the model is not cached');
}
console.log('Downloading tokenizer model:', model);
const response = await fetch(model);
if (!response.ok) {
throw new Error(`Failed to fetch the model: ${response.status} ${response.statusText}`);
}
const arrayBuffer = await response.arrayBuffer();
writeFileAtomicSync(cachedFile, Buffer.from(arrayBuffer));
return cachedFile;
} catch (error) {
const getLastSegment = str => str?.split('/')?.pop() || '';
if (fallbackModel) {
console.log(`Could not get a tokenizer from ${getLastSegment(model)}. Reason: ${error.message}. Using a fallback model: ${getLastSegment(fallbackModel)}.`);
return fallbackModel;
}
throw new Error(`Failed to instantiate a tokenizer and fallback is not provided. Reason: ${error.message}`);
}
}
/**
* Sentencepiece tokenizer for tokenizing text.
@ -66,13 +124,19 @@ class SentencePieceTokenizer {
* @type {string} Path to the tokenizer model
*/
#model;
/**
* @type {string|undefined} Path to the fallback model
*/
#fallbackModel;
/**
* Creates a new Sentencepiece tokenizer.
* @param {string} model Path to the tokenizer model
* @param {string} [fallbackModel] Path to the fallback model
*/
constructor(model) {
constructor(model, fallbackModel) {
this.#model = model;
this.#fallbackModel = fallbackModel;
}
/**
@ -85,9 +149,10 @@ class SentencePieceTokenizer {
}
try {
const pathToModel = await getPathToTokenizer(this.#model, this.#fallbackModel);
this.#instance = new SentencePieceProcessor();
await this.#instance.load(this.#model);
console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
await this.#instance.load(pathToModel);
console.log('Instantiated the tokenizer for', path.parse(pathToModel).name);
return this.#instance;
} catch (error) {
console.error('Sentencepiece tokenizer failed to load: ' + this.#model, error);
@ -108,13 +173,19 @@ class WebTokenizer {
* @type {string} Path to the tokenizer model
*/
#model;
/**
* @type {string|undefined} Path to the fallback model
*/
#fallbackModel;
/**
* Creates a new Web tokenizer.
* @param {string} model Path to the tokenizer model
* @param {string} [fallbackModel] Path to the fallback model
*/
constructor(model) {
constructor(model, fallbackModel) {
this.#model = model;
this.#fallbackModel = fallbackModel;
}
/**
@ -127,9 +198,10 @@ class WebTokenizer {
}
try {
const arrayBuffer = fs.readFileSync(this.#model).buffer;
const pathToModel = await getPathToTokenizer(this.#model, this.#fallbackModel);
const arrayBuffer = fs.readFileSync(pathToModel).buffer;
this.#instance = await Tokenizer.fromJSON(arrayBuffer);
console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
console.log('Instantiated the tokenizer for', path.parse(pathToModel).name);
return this.#instance;
} catch (error) {
console.error('Web tokenizer failed to load: ' + this.#model, error);
@ -147,6 +219,8 @@ const spp_gemma = new SentencePieceTokenizer('src/tokenizers/gemma.model');
const spp_jamba = new SentencePieceTokenizer('src/tokenizers/jamba.model');
const claude_tokenizer = new WebTokenizer('src/tokenizers/claude.json');
const llama3_tokenizer = new WebTokenizer('src/tokenizers/llama3.json');
const commandTokenizer = new WebTokenizer('https://github.com/SillyTavern/SillyTavern-Tokenizers/raw/main/command-r.json', 'src/tokenizers/llama3.json');
const qwen2Tokenizer = new WebTokenizer('https://github.com/SillyTavern/SillyTavern-Tokenizers/raw/main/qwen2.json', 'src/tokenizers/llama3.json');
const sentencepieceTokenizers = [
'llama',
@ -332,6 +406,14 @@ function getTokenizerModel(requestModel) {
return 'jamba';
}
if (requestModel.includes('qwen2')) {
return 'qwen2';
}
if (requestModel.includes('command-r')) {
return 'command-r';
}
// default
return 'gpt-3.5-turbo';
}
@ -557,6 +639,8 @@ router.post('/jamba/encode', jsonParser, createSentencepieceEncodingHandler(spp_
router.post('/gpt2/encode', jsonParser, createTiktokenEncodingHandler('gpt2'));
router.post('/claude/encode', jsonParser, createWebTokenizerEncodingHandler(claude_tokenizer));
router.post('/llama3/encode', jsonParser, createWebTokenizerEncodingHandler(llama3_tokenizer));
router.post('/qwen2/encode', jsonParser, createWebTokenizerEncodingHandler(qwen2Tokenizer));
router.post('/command-r/encode', jsonParser, createWebTokenizerEncodingHandler(commandTokenizer));
router.post('/llama/decode', jsonParser, createSentencepieceDecodingHandler(spp_llama));
router.post('/nerdstash/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd));
router.post('/nerdstash_v2/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
@ -567,6 +651,8 @@ router.post('/jamba/decode', jsonParser, createSentencepieceDecodingHandler(spp_
router.post('/gpt2/decode', jsonParser, createTiktokenDecodingHandler('gpt2'));
router.post('/claude/decode', jsonParser, createWebTokenizerDecodingHandler(claude_tokenizer));
router.post('/llama3/decode', jsonParser, createWebTokenizerDecodingHandler(llama3_tokenizer));
router.post('/qwen2/decode', jsonParser, createWebTokenizerDecodingHandler(qwen2Tokenizer));
router.post('/command-r/decode', jsonParser, createWebTokenizerDecodingHandler(commandTokenizer));
router.post('/openai/encode', jsonParser, async function (req, res) {
try {
@ -607,6 +693,16 @@ router.post('/openai/encode', jsonParser, async function (req, res) {
return handler(req, res);
}
if (queryModel.includes('qwen2')) {
const handler = createWebTokenizerEncodingHandler(qwen2Tokenizer);
return handler(req, res);
}
if (queryModel.includes('command-r')) {
const handler = createWebTokenizerEncodingHandler(commandTokenizer);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
const handler = createTiktokenEncodingHandler(model);
return handler(req, res);
@ -655,6 +751,16 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
return handler(req, res);
}
if (queryModel.includes('qwen2')) {
const handler = createWebTokenizerDecodingHandler(qwen2Tokenizer);
return handler(req, res);
}
if (queryModel.includes('command-r')) {
const handler = createWebTokenizerDecodingHandler(commandTokenizer);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
const handler = createTiktokenDecodingHandler(model);
return handler(req, res);
@ -711,6 +817,20 @@ router.post('/openai/count', jsonParser, async function (req, res) {
return res.send({ 'token_count': num_tokens });
}
if (model === 'qwen2') {
const instance = await qwen2Tokenizer.get();
if (!instance) throw new Error('Failed to load the Qwen2 tokenizer');
num_tokens = countWebTokenizerTokens(instance, req.body);
return res.send({ 'token_count': num_tokens });
}
if (model === 'command-r') {
const instance = await commandTokenizer.get();
if (!instance) throw new Error('Failed to load the Command-R tokenizer');
num_tokens = countWebTokenizerTokens(instance, req.body);
return res.send({ 'token_count': num_tokens });
}
const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
const tokensPadding = 3;