@@ -4,13 +4,12 @@ const express = require('express');
 const { SentencePieceProcessor } = require('@agnai/sentencepiece-js');
 const tiktoken = require('tiktoken');
 const { Tokenizer } = require('@agnai/web-tokenizers');
-const { convertClaudePrompt, convertGooglePrompt } = require('../prompt-converters');
-const { readSecret, SECRET_KEYS } = require('./secrets');
+const { convertClaudePrompt } = require('../prompt-converters');
 const { TEXTGEN_TYPES } = require('../constants');
 const { jsonParser } = require('../express-common');
 const { setAdditionalHeaders } = require('../additional-headers');
-
-const API_MAKERSUITE = 'https://generativelanguage.googleapis.com';
+const { getConfigValue, isValidUrl } = require('../util');
+const writeFileAtomicSync = require('write-file-atomic').sync;

 /**
  * @typedef { (req: import('express').Request, res: import('express').Response) => Promise<any> } TokenizationHandler
@@ -53,6 +52,65 @@ const TEXT_COMPLETION_MODELS = [
 ];

 const CHARS_PER_TOKEN = 3.35;
+const IS_DOWNLOAD_ALLOWED = getConfigValue('enableDownloadableTokenizers', true);
+
+/**
+ * Gets a path to the tokenizer model. Downloads the model if it's a URL.
+ * @param {string} model Model URL or path
+ * @param {string|undefined} fallbackModel Fallback model path
+ * @returns {Promise<string>} Path to the tokenizer model
+ */
+async function getPathToTokenizer(model, fallbackModel) {
+    if (!isValidUrl(model)) {
+        return model;
+    }
+
+    try {
+        const url = new URL(model);
+
+        if (!['https:', 'http:'].includes(url.protocol)) {
+            throw new Error('Invalid URL protocol');
+        }
+
+        const fileName = url.pathname.split('/').pop();
+
+        if (!fileName) {
+            throw new Error('Failed to extract the file name from the URL');
+        }
+
+        const CACHE_PATH = path.join(global.DATA_ROOT, '_cache');
+        if (!fs.existsSync(CACHE_PATH)) {
+            fs.mkdirSync(CACHE_PATH, { recursive: true });
+        }
+
+        const cachedFile = path.join(CACHE_PATH, fileName);
+        if (fs.existsSync(cachedFile)) {
+            return cachedFile;
+        }
+
+        if (!IS_DOWNLOAD_ALLOWED) {
+            throw new Error('Downloading tokenizers is disabled, the model is not cached');
+        }
+
+        console.log('Downloading tokenizer model:', model);
+        const response = await fetch(model);
+        if (!response.ok) {
+            throw new Error(`Failed to fetch the model: ${response.status} ${response.statusText}`);
+        }
+
+        const arrayBuffer = await response.arrayBuffer();
+        writeFileAtomicSync(cachedFile, Buffer.from(arrayBuffer));
+        return cachedFile;
+    } catch (error) {
+        const getLastSegment = str => str?.split('/')?.pop() || '';
+        if (fallbackModel) {
+            console.log(`Could not get a tokenizer from ${getLastSegment(model)}. Reason: ${error.message}. Using a fallback model: ${getLastSegment(fallbackModel)}.`);
+            return fallbackModel;
+        }
+
+        throw new Error(`Failed to instantiate a tokenizer and fallback is not provided. Reason: ${error.message}`);
+    }
+}

 /**
  * Sentencepiece tokenizer for tokenizing text.
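[Editor's note: a minimal usage sketch of the helper added above, not part of the diff. The call sites are hypothetical; run inside an async context.]

// A plain file path short-circuits the download logic and is returned as-is.
const localPath = await getPathToTokenizer('src/tokenizers/llama3.json');
// -> 'src/tokenizers/llama3.json'

// An HTTP(S) URL is fetched once, written atomically to <DATA_ROOT>/_cache/<file name>,
// and the cached path is returned on every later call. If the download fails or
// enableDownloadableTokenizers is false, the fallback path is used instead.
const remotePath = await getPathToTokenizer(
    'https://github.com/SillyTavern/SillyTavern-Tokenizers/raw/main/qwen2.json',
    'src/tokenizers/llama3.json',
);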
@@ -66,13 +124,19 @@ class SentencePieceTokenizer {
      * @type {string} Path to the tokenizer model
      */
     #model;
+    /**
+     * @type {string|undefined} Path to the fallback model
+     */
+    #fallbackModel;

     /**
      * Creates a new Sentencepiece tokenizer.
      * @param {string} model Path to the tokenizer model
+     * @param {string} [fallbackModel] Path to the fallback model
      */
-    constructor(model) {
+    constructor(model, fallbackModel) {
         this.#model = model;
+        this.#fallbackModel = fallbackModel;
     }

     /**
@@ -85,9 +149,10 @@ class SentencePieceTokenizer {
         }

         try {
+            const pathToModel = await getPathToTokenizer(this.#model, this.#fallbackModel);
             this.#instance = new SentencePieceProcessor();
-            await this.#instance.load(this.#model);
-            console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
+            await this.#instance.load(pathToModel);
+            console.log('Instantiated the tokenizer for', path.parse(pathToModel).name);
             return this.#instance;
         } catch (error) {
             console.error('Sentencepiece tokenizer failed to load: ' + this.#model, error);
@@ -108,13 +173,19 @@ class WebTokenizer {
      * @type {string} Path to the tokenizer model
      */
     #model;
+    /**
+     * @type {string|undefined} Path to the fallback model
+     */
+    #fallbackModel;

     /**
      * Creates a new Web tokenizer.
      * @param {string} model Path to the tokenizer model
+     * @param {string} [fallbackModel] Path to the fallback model
      */
-    constructor(model) {
+    constructor(model, fallbackModel) {
         this.#model = model;
+        this.#fallbackModel = fallbackModel;
     }

     /**
@@ -127,9 +198,10 @@ class WebTokenizer {
         }

         try {
-            const arrayBuffer = fs.readFileSync(this.#model).buffer;
+            const pathToModel = await getPathToTokenizer(this.#model, this.#fallbackModel);
+            const arrayBuffer = fs.readFileSync(pathToModel).buffer;
             this.#instance = await Tokenizer.fromJSON(arrayBuffer);
-            console.log('Instantiated the tokenizer for', path.parse(this.#model).name);
+            console.log('Instantiated the tokenizer for', path.parse(pathToModel).name);
             return this.#instance;
         } catch (error) {
             console.error('Web tokenizer failed to load: ' + this.#model, error);
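[Editor's note: combined with the constructor changes above, loading stays lazy: nothing is downloaded until the first get() call, and a failed download degrades to the local fallback. A hypothetical usage sketch, not part of the diff:]

// A WebTokenizer pointed at a remote model, with a bundled local fallback.
const tokenizer = new WebTokenizer(
    'https://example.com/tokenizers/custom.json', // hypothetical URL
    'src/tokenizers/llama3.json',                 // used if the download fails
);

// The first get() resolves the model path (download + cache, or fallback)
// and instantiates the tokenizer; later calls return the memoized #instance.
const instance = await tokenizer.get();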
@@ -147,6 +219,8 @@ const spp_gemma = new SentencePieceTokenizer('src/tokenizers/gemma.model');
 const spp_jamba = new SentencePieceTokenizer('src/tokenizers/jamba.model');
 const claude_tokenizer = new WebTokenizer('src/tokenizers/claude.json');
 const llama3_tokenizer = new WebTokenizer('src/tokenizers/llama3.json');
+const commandTokenizer = new WebTokenizer('https://github.com/SillyTavern/SillyTavern-Tokenizers/raw/main/command-r.json', 'src/tokenizers/llama3.json');
+const qwen2Tokenizer = new WebTokenizer('https://github.com/SillyTavern/SillyTavern-Tokenizers/raw/main/qwen2.json', 'src/tokenizers/llama3.json');

 const sentencepieceTokenizers = [
     'llama',
@@ -332,6 +406,14 @@ function getTokenizerModel(requestModel) {
         return 'jamba';
     }

+    if (requestModel.includes('qwen2')) {
+        return 'qwen2';
+    }
+
+    if (requestModel.includes('command-r')) {
+        return 'command-r';
+    }
+
     // default
     return 'gpt-3.5-turbo';
 }
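[Editor's note: the effect of the two new branches, sketched with illustrative model names. Matching is by substring, checked before the GPT default:]

getTokenizerModel('qwen2-72b-instruct'); // -> 'qwen2'
getTokenizerModel('command-r-plus');     // -> 'command-r'
getTokenizerModel('some-unknown-model'); // -> 'gpt-3.5-turbo' (default)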
@@ -557,6 +639,8 @@ router.post('/jamba/encode', jsonParser, createSentencepieceEncodingHandler(spp_jamba));
 router.post('/gpt2/encode', jsonParser, createTiktokenEncodingHandler('gpt2'));
 router.post('/claude/encode', jsonParser, createWebTokenizerEncodingHandler(claude_tokenizer));
 router.post('/llama3/encode', jsonParser, createWebTokenizerEncodingHandler(llama3_tokenizer));
+router.post('/qwen2/encode', jsonParser, createWebTokenizerEncodingHandler(qwen2Tokenizer));
+router.post('/command-r/encode', jsonParser, createWebTokenizerEncodingHandler(commandTokenizer));
 router.post('/llama/decode', jsonParser, createSentencepieceDecodingHandler(spp_llama));
 router.post('/nerdstash/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd));
 router.post('/nerdstash_v2/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
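[Editor's note: the new routes mirror the existing per-tokenizer endpoints. A hypothetical client call, assuming the router is mounted at /api/tokenizers like the other tokenizer routes and that the handlers follow the shared encode convention:]

const response = await fetch('/api/tokenizers/qwen2/encode', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ text: 'Hello, world!' }),
});
const data = await response.json(); // assumed shape: { ids, count, chunks }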
@@ -567,6 +651,8 @@ router.post('/jamba/decode', jsonParser, createSentencepieceDecodingHandler(spp_jamba));
 router.post('/gpt2/decode', jsonParser, createTiktokenDecodingHandler('gpt2'));
 router.post('/claude/decode', jsonParser, createWebTokenizerDecodingHandler(claude_tokenizer));
 router.post('/llama3/decode', jsonParser, createWebTokenizerDecodingHandler(llama3_tokenizer));
+router.post('/qwen2/decode', jsonParser, createWebTokenizerDecodingHandler(qwen2Tokenizer));
+router.post('/command-r/decode', jsonParser, createWebTokenizerDecodingHandler(commandTokenizer));

 router.post('/openai/encode', jsonParser, async function (req, res) {
     try {
@@ -607,6 +693,16 @@ router.post('/openai/encode', jsonParser, async function (req, res) {
             return handler(req, res);
         }

+        if (queryModel.includes('qwen2')) {
+            const handler = createWebTokenizerEncodingHandler(qwen2Tokenizer);
+            return handler(req, res);
+        }
+
+        if (queryModel.includes('command-r')) {
+            const handler = createWebTokenizerEncodingHandler(commandTokenizer);
+            return handler(req, res);
+        }
+
         const model = getTokenizerModel(queryModel);
         const handler = createTiktokenEncodingHandler(model);
         return handler(req, res);
@@ -655,6 +751,16 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
             return handler(req, res);
         }

+        if (queryModel.includes('qwen2')) {
+            const handler = createWebTokenizerDecodingHandler(qwen2Tokenizer);
+            return handler(req, res);
+        }
+
+        if (queryModel.includes('command-r')) {
+            const handler = createWebTokenizerDecodingHandler(commandTokenizer);
+            return handler(req, res);
+        }
+
         const model = getTokenizerModel(queryModel);
         const handler = createTiktokenDecodingHandler(model);
         return handler(req, res);
@@ -711,6 +817,20 @@ router.post('/openai/count', jsonParser, async function (req, res) {
             return res.send({ 'token_count': num_tokens });
         }

+        if (model === 'qwen2') {
+            const instance = await qwen2Tokenizer.get();
+            if (!instance) throw new Error('Failed to load the Qwen2 tokenizer');
+            num_tokens = countWebTokenizerTokens(instance, req.body);
+            return res.send({ 'token_count': num_tokens });
+        }
+
+        if (model === 'command-r') {
+            const instance = await commandTokenizer.get();
+            if (!instance) throw new Error('Failed to load the Command-R tokenizer');
+            num_tokens = countWebTokenizerTokens(instance, req.body);
+            return res.send({ 'token_count': num_tokens });
+        }
+
         const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
         const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
         const tokensPadding = 3;
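[Editor's note: models that fall through to the tiktoken path are counted with OpenAI-style per-message accounting using the constants above. A paraphrased sketch of that accounting (not the literal handler code), assuming a countTokens helper:]

// Estimate chat tokens the way the handler does for non-0301 models.
function estimateChatTokens(messages, countTokens) {
    const tokensPerMessage = 3; // 4 for gpt-3.5-turbo-0301
    const tokensPerName = 1;    // -1 for gpt-3.5-turbo-0301
    const tokensPadding = 3;
    let total = tokensPadding;
    for (const msg of messages) {
        total += tokensPerMessage;
        for (const [key, value] of Object.entries(msg)) {
            total += countTokens(String(value));
            if (key === 'name') total += tokensPerName; // named messages cost extra
        }
    }
    return total;
}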