Implement downloadable tokenizers

Closes #2574, #2754
This commit is contained in:
Cohee 2024-09-06 16:28:34 +00:00
parent 4a9401bfe2
commit 81251b073a
6 changed files with 187 additions and 10 deletions

View File

@ -95,6 +95,9 @@ requestOverrides: []
enableExtensions: true
# Automatically update extensions when a release version changes
enableExtensionsAutoUpdate: true
# Additional model tokenizers can be downloaded on demand.
# Disabling will fallback to another locally available tokenizer.
enableDownloadableTokenizers: true
# Extension settings
extras:
# Disables automatic model download from HuggingFace

View File

@ -3283,6 +3283,8 @@
<option value="12">Llama 3</option>
<option value="13">Gemma / Gemini</option>
<option value="14">Jamba</option>
<option value="15">Qwen2</option>
<option value="16">Command-R</option>
<option value="4">NerdStash (NovelAI Clio)</option>
<option value="5">NerdStash v2 (NovelAI Kayra)</option>
<option value="7">Mistral</option>

View File

@ -590,6 +590,9 @@ function calculateOpenRouterCost() {
export function getCurrentOpenRouterModelTokenizer() {
const modelId = textgen_settings.openrouter_model;
const model = openRouterModels.find(x => x.id === modelId);
if (modelId?.includes('jamba')) {
return tokenizers.JAMBA;
}
switch (model?.architecture?.tokenizer) {
case 'Llama2':
return tokenizers.LLAMA;
@ -603,6 +606,10 @@ export function getCurrentOpenRouterModelTokenizer() {
return tokenizers.GEMMA;
case 'Claude':
return tokenizers.CLAUDE;
case 'Cohere':
return tokenizers.COMMAND_R;
case 'Qwen':
return tokenizers.QWEN2;
default:
return tokenizers.OPENAI;
}

View File

@ -28,6 +28,8 @@ export const tokenizers = {
LLAMA3: 12,
GEMMA: 13,
JAMBA: 14,
QWEN2: 15,
COMMAND_R: 16,
BEST_MATCH: 99,
};
@ -105,6 +107,16 @@ const TOKENIZER_URLS = {
decode: '/api/tokenizers/jamba/decode',
count: '/api/tokenizers/jamba/encode',
},
[tokenizers.QWEN2]: {
encode: '/api/tokenizers/qwen2/encode',
decode: '/api/tokenizers/qwen2/decode',
count: '/api/tokenizers/qwen2/encode',
},
[tokenizers.COMMAND_R]: {
encode: '/api/tokenizers/command-r/encode',
decode: '/api/tokenizers/command-r/decode',
count: '/api/tokenizers/command-r/encode',
},
[tokenizers.API_TEXTGENERATIONWEBUI]: {
encode: '/api/tokenizers/remote/textgenerationwebui/encode',
count: '/api/tokenizers/remote/textgenerationwebui/encode',
@ -293,6 +305,12 @@ export function getTokenizerBestMatch(forApi) {
if (model.includes('jamba')) {
return tokenizers.JAMBA;
}
if (model.includes('command-r')) {
return tokenizers.COMMAND_R;
}
if (model.includes('qwen2')) {
return tokenizers.QWEN2;
}
}
return tokenizers.LLAMA;
@ -511,6 +529,8 @@ export function getTokenizerModel() {
const yiTokenizer = 'yi';
const gemmaTokenizer = 'gemma';
const jambaTokenizer = 'jamba';
const qwen2Tokenizer = 'qwen2';
const commandRTokenizer = 'command-r';
// Assuming no one would use it for different models.. right?
if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) {
@ -558,6 +578,12 @@ export function getTokenizerModel() {
else if (model?.architecture?.tokenizer === 'Gemini') {
return gemmaTokenizer;
}
else if (model?.architecture?.tokenizer === 'Qwen') {
return qwen2Tokenizer;
}
else if (model?.architecture?.tokenizer === 'Cohere') {
return commandRTokenizer;
}
else if (oai_settings.openrouter_model.includes('gpt-4o')) {
return gpt4oTokenizer;
}
@ -581,6 +607,10 @@ export function getTokenizerModel() {
}
}
if (oai_settings.chat_completion_source == chat_completion_sources.COHERE) {
return commandRTokenizer;
}
if (oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) {
return gemmaTokenizer;
}

View File

@ -4,13 +4,12 @@ const express = require('express');
const { SentencePieceProcessor } = require('@agnai/sentencepiece-js');
const tiktoken = require('tiktoken');
const { Tokenizer } = require('@agnai/web-tokenizers');
const { convertClaudePrompt, convertGooglePrompt } = require('../prompt-converters');
const { readSecret, SECRET_KEYS } = require('./secrets');
const { convertClaudePrompt } = require('../prompt-converters');
const { TEXTGEN_TYPES } = require('../constants');
const { jsonParser } = require('../express-common');
const { setAdditionalHeaders } = require('../additional-headers');
const API_MAKERSUITE = 'https://generativelanguage.googleapis.com';
const { getConfigValue, isValidUrl } = require('../util');
const writeFileAtomicSync = require('write-file-atomic').sync;
/**
* @typedef { (req: import('express').Request, res: import('express').Response) => Promise<any> } TokenizationHandler
@ -53,6 +52,65 @@ const TEXT_COMPLETION_MODELS = [
];
const CHARS_PER_TOKEN = 3.35;
const IS_DOWNLOAD_ALLOWED = getConfigValue('enableDownloadableTokenizers', true);
/**
* Gets a path to the tokenizer model. Downloads the model if it's a URL.
* @param {string} model Model URL or path
* @param {string|undefined} fallbackModel Fallback model path\
* @returns {Promise<string>} Path to the tokenizer model
*/
async function getPathToTokenizer(model, fallbackModel) {
if (!isValidUrl(model)) {
return model;
}
try {
const url = new URL(model);
if (!['https:', 'http:'].includes(url.protocol)) {
throw new Error('Invalid URL protocol');
}
const fileName = url.pathname.split('/').pop();
if (!fileName) {
throw new Error('Failed to extract the file name from the URL');
}
const CACHE_PATH = path.join(global.DATA_ROOT, '_cache');
if (!fs.existsSync(CACHE_PATH)) {
fs.mkdirSync(CACHE_PATH, { recursive: true });
}
const cachedFile = path.join(CACHE_PATH, fileName);
if (fs.existsSync(cachedFile)) {
return cachedFile;
}
if (!IS_DOWNLOAD_ALLOWED) {
throw new Error('Downloading tokenizers is disabled, the model is not cached');
}
console.log('Downloading tokenizer model:', model);
const response = await fetch(model);
if (!response.ok) {
throw new Error(`Failed to fetch the model: ${response.status} ${response.statusText}`);
}
const arrayBuffer = await response.arrayBuffer();
writeFileAtomicSync(cachedFile, Buffer.from(arrayBuffer));
return cachedFile;
} catch (error) {
const getLastSegment = str => str?.split('/')?.pop() || '';
if (fallbackModel) {
console.log(`Could not get a tokenizer from ${getLastSegment(model)}. Reason: ${error.message}. Using a fallback model: ${getLastSegment(fallbackModel)}.`);
return fallbackModel;
}
throw new Error(`Failed to instantiate a tokenizer and fallback is not provided. Reason: ${error.message}`);
}
}
/**
* Sentencepiece tokenizer for tokenizing text.
@ -66,13 +124,19 @@ class SentencePieceTokenizer {
* @type {string} Path to the tokenizer model
*/
#model;
/**
* @type {string|undefined} Path to the fallback model
*/
#fallbackModel;
/**
* Creates a new Sentencepiece tokenizer.
* @param {string} model Path to the tokenizer model
* @param {string} [fallbackModel] Path to the fallback model
*/
constructor(model) {
constructor(model, fallbackModel) {
this.#model = model;
this.#fallbackModel = fallbackModel;
}
/**
@ -85,9 +149,10 @@ class SentencePieceTokenizer {
}
try {
const pathToModel = await getPathToTokenizer(this.#model, this.#fallbackModel);
this.#instance = new SentencePieceProcessor();
await this.#instance.load(this.#model);
console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
await this.#instance.load(pathToModel);
console.log('Instantiated the tokenizer for', path.parse(pathToModel).name);
return this.#instance;
} catch (error) {
console.error('Sentencepiece tokenizer failed to load: ' + this.#model, error);
@ -108,13 +173,19 @@ class WebTokenizer {
* @type {string} Path to the tokenizer model
*/
#model;
/**
* @type {string|undefined} Path to the fallback model
*/
#fallbackModel;
/**
* Creates a new Web tokenizer.
* @param {string} model Path to the tokenizer model
* @param {string} [fallbackModel] Path to the fallback model
*/
constructor(model) {
constructor(model, fallbackModel) {
this.#model = model;
this.#fallbackModel = fallbackModel;
}
/**
@ -127,9 +198,10 @@ class WebTokenizer {
}
try {
const arrayBuffer = fs.readFileSync(this.#model).buffer;
const pathToModel = await getPathToTokenizer(this.#model, this.#fallbackModel);
const arrayBuffer = fs.readFileSync(pathToModel).buffer;
this.#instance = await Tokenizer.fromJSON(arrayBuffer);
console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
console.log('Instantiated the tokenizer for', path.parse(pathToModel).name);
return this.#instance;
} catch (error) {
console.error('Web tokenizer failed to load: ' + this.#model, error);
@ -147,6 +219,8 @@ const spp_gemma = new SentencePieceTokenizer('src/tokenizers/gemma.model');
const spp_jamba = new SentencePieceTokenizer('src/tokenizers/jamba.model');
const claude_tokenizer = new WebTokenizer('src/tokenizers/claude.json');
const llama3_tokenizer = new WebTokenizer('src/tokenizers/llama3.json');
const commandTokenizer = new WebTokenizer('https://github.com/SillyTavern/SillyTavern-Tokenizers/raw/main/command-r.json', 'src/tokenizers/llama3.json');
const qwen2Tokenizer = new WebTokenizer('https://github.com/SillyTavern/SillyTavern-Tokenizers/raw/main/qwen2.json', 'src/tokenizers/llama3.json');
const sentencepieceTokenizers = [
'llama',
@ -332,6 +406,14 @@ function getTokenizerModel(requestModel) {
return 'jamba';
}
if (requestModel.includes('qwen2')) {
return 'qwen2';
}
if (requestModel.includes('command-r')) {
return 'command-r';
}
// default
return 'gpt-3.5-turbo';
}
@ -557,6 +639,8 @@ router.post('/jamba/encode', jsonParser, createSentencepieceEncodingHandler(spp_
router.post('/gpt2/encode', jsonParser, createTiktokenEncodingHandler('gpt2'));
router.post('/claude/encode', jsonParser, createWebTokenizerEncodingHandler(claude_tokenizer));
router.post('/llama3/encode', jsonParser, createWebTokenizerEncodingHandler(llama3_tokenizer));
router.post('/qwen2/encode', jsonParser, createWebTokenizerEncodingHandler(qwen2Tokenizer));
router.post('/command-r/encode', jsonParser, createWebTokenizerEncodingHandler(commandTokenizer));
router.post('/llama/decode', jsonParser, createSentencepieceDecodingHandler(spp_llama));
router.post('/nerdstash/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd));
router.post('/nerdstash_v2/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
@ -567,6 +651,8 @@ router.post('/jamba/decode', jsonParser, createSentencepieceDecodingHandler(spp_
router.post('/gpt2/decode', jsonParser, createTiktokenDecodingHandler('gpt2'));
router.post('/claude/decode', jsonParser, createWebTokenizerDecodingHandler(claude_tokenizer));
router.post('/llama3/decode', jsonParser, createWebTokenizerDecodingHandler(llama3_tokenizer));
router.post('/qwen2/decode', jsonParser, createWebTokenizerDecodingHandler(qwen2Tokenizer));
router.post('/command-r/decode', jsonParser, createWebTokenizerDecodingHandler(commandTokenizer));
router.post('/openai/encode', jsonParser, async function (req, res) {
try {
@ -607,6 +693,16 @@ router.post('/openai/encode', jsonParser, async function (req, res) {
return handler(req, res);
}
if (queryModel.includes('qwen2')) {
const handler = createWebTokenizerEncodingHandler(qwen2Tokenizer);
return handler(req, res);
}
if (queryModel.includes('command-r')) {
const handler = createWebTokenizerEncodingHandler(commandTokenizer);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
const handler = createTiktokenEncodingHandler(model);
return handler(req, res);
@ -655,6 +751,16 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
return handler(req, res);
}
if (queryModel.includes('qwen2')) {
const handler = createWebTokenizerDecodingHandler(qwen2Tokenizer);
return handler(req, res);
}
if (queryModel.includes('command-r')) {
const handler = createWebTokenizerDecodingHandler(commandTokenizer);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
const handler = createTiktokenDecodingHandler(model);
return handler(req, res);
@ -711,6 +817,20 @@ router.post('/openai/count', jsonParser, async function (req, res) {
return res.send({ 'token_count': num_tokens });
}
if (model === 'qwen2') {
const instance = await qwen2Tokenizer.get();
if (!instance) throw new Error('Failed to load the Qwen2 tokenizer');
num_tokens = countWebTokenizerTokens(instance, req.body);
return res.send({ 'token_count': num_tokens });
}
if (model === 'command-r') {
const instance = await commandTokenizer.get();
if (!instance) throw new Error('Failed to load the Command-R tokenizer');
num_tokens = countWebTokenizerTokens(instance, req.body);
return res.send({ 'token_count': num_tokens });
}
const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
const tokensPadding = 3;

View File

@ -647,6 +647,20 @@ function getSeparator(n) {
return '='.repeat(n);
}
/**
* Checks if the string is a valid URL.
* @param {string} url String to check
* @returns {boolean} If the URL is valid
*/
function isValidUrl(url) {
try {
new URL(url);
return true;
} catch (error) {
return false;
}
}
module.exports = {
getConfig,
getConfigValue,
@ -676,4 +690,5 @@ module.exports = {
makeHttp2Request,
removeColorFormatting,
getSeparator,
isValidUrl,
};