Mirror of https://github.com/SillyTavern/SillyTavern.git (synced 2025-06-05 21:59:27 +02:00)

Add Claude tokenizer
server.js — 65 changed lines

@@ -128,10 +128,13 @@ let response_getstatus;
 const delay = ms => new Promise(resolve => setTimeout(resolve, ms))
 
 const { SentencePieceProcessor, cleanText } = require("sentencepiece-js");
+const { Tokenizer } = require('@mlc-ai/web-tokenizers');
+const CHARS_PER_TOKEN = 3.35;
 
 let spp_llama;
 let spp_nerd;
 let spp_nerd_v2;
+let claude_tokenizer;
 
 async function loadSentencepieceTokenizer(modelPath) {
     try {
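
Note: CHARS_PER_TOKEN lifts the magic number 3.35 out of countSentencepieceTokens so the same characters-per-token estimate can back the Claude fallback added below. A quick sanity check of the heuristic (the estimateTokens helper is illustrative, not part of this diff):

    const CHARS_PER_TOKEN = 3.35;
    // Round up so short strings never estimate to zero tokens.
    const estimateTokens = (text) => Math.ceil(text.length / CHARS_PER_TOKEN);
    console.log(estimateTokens('Hello, Claude!')); // 14 chars -> 5 tokens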
@@ -147,7 +150,7 @@ async function loadSentencepieceTokenizer(modelPath) {
 async function countSentencepieceTokens(spp, text) {
     // Fallback to strlen estimation
     if (!spp) {
-        return Math.ceil(text.length / 3.35);
+        return Math.ceil(text.length / CHARS_PER_TOKEN);
     }
 
     let cleaned = cleanText(text);
@@ -156,9 +159,36 @@ async function countSentencepieceTokens(spp, text) {
     return ids.length;
 }
 
+async function loadClaudeTokenizer(modelPath) {
+    try {
+        const arrayBuffer = fs.readFileSync(modelPath).buffer;
+        const instance = await Tokenizer.fromJSON(arrayBuffer);
+        return instance;
+    } catch (error) {
+        console.error("Claude tokenizer failed to load: " + modelPath, error);
+        return null;
+    }
+}
+
+function countClaudeTokens(tokenizer, messages) {
+    const convertedPrompt = convertClaudePrompt(messages, false, false);
+
+    // Fallback to strlen estimation
+    if (!tokenizer) {
+        return Math.ceil(convertedPrompt.length / CHARS_PER_TOKEN);
+    }
+
+    const count = tokenizer.encode(convertedPrompt).length;
+    return count;
+}
+
 const tokenizersCache = {};
 
 function getTokenizerModel(requestModel) {
+    if (requestModel.includes('claude')) {
+        return 'claude';
+    }
+
     if (requestModel.includes('gpt-4-32k')) {
         return 'gpt-4-32k';
     }
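
Note: loadClaudeTokenizer reads the tokenizer JSON into an ArrayBuffer and hands it to Tokenizer.fromJSON from @mlc-ai/web-tokenizers; encode() yields a sequence of token ids, so its length is the token count. A minimal sketch of how the two new helpers compose, assuming both functions from this diff are in scope (the messages are made up):

    async function demoClaudeCount() {
        // Returns null if src/claude.json is missing; countClaudeTokens
        // then falls back to the CHARS_PER_TOKEN estimate.
        const tokenizer = await loadClaudeTokenizer('src/claude.json');
        const messages = [
            { role: 'system', content: 'You are a helpful assistant.' },
            { role: 'user', content: 'Hello!' },
        ];
        console.log('Prompt tokens:', countClaudeTokens(tokenizer, messages));
    }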
@@ -2870,6 +2900,12 @@ app.post("/openai_bias", jsonParser, async function (request, response) {
     let result = {};
 
     const model = getTokenizerModel(String(request.query.model || ''));
+
+    // no bias for claude
+    if (model == 'claude') {
+        return response.send(result);
+    }
+
     const tokenizer = getTiktokenTokenizer(model);
 
     for (const entry of request.body) {
@@ -2942,7 +2978,7 @@ app.post("/deletepreset_openai", jsonParser, function (request, response) {
 });
 
 // Prompt Conversion script taken from RisuAI by @kwaroran (GPLv3).
-function convertClaudePrompt(messages) {
+function convertClaudePrompt(messages, addHumanPrefix, addAssistantPostfix) {
     // Claude doesn't support message names, so we'll just add them to the message content.
     for (const message of messages) {
         if (message.name && message.role !== "system") {
@@ -2972,7 +3008,16 @@ function convertClaudePrompt(messages) {
                 break
         }
         return prefix + v.content;
-    }).join('') + '\n\nAssistant: ';
+    }).join('');
+
+    if (addHumanPrefix) {
+        requestPrompt = "\n\nHuman: " + requestPrompt;
+    }
+
+    if (addAssistantPostfix) {
+        requestPrompt = requestPrompt + '\n\nAssistant: ';
+    }
+
     return requestPrompt;
 }
 
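
Note: the two new flags let countClaudeTokens measure only the message text (both false), while sendClaudeRequest reproduces the old inline framing (both true). The flag behavior in isolation, with the mapped prompt body shortened to "...":

    let requestPrompt = '...';                          // messages.map(...).join('')
    // addHumanPrefix:
    requestPrompt = "\n\nHuman: " + requestPrompt;      // "\n\nHuman: ..."
    // addAssistantPostfix:
    requestPrompt = requestPrompt + '\n\nAssistant: ';  // "\n\nHuman: ...\n\nAssistant: "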
@@ -2993,14 +3038,14 @@ async function sendClaudeRequest(request, response) {
         controller.abort();
     });
 
-    const requestPrompt = convertClaudePrompt(request.body.messages);
+    const requestPrompt = convertClaudePrompt(request.body.messages, true, true);
     console.log('Claude request:', requestPrompt);
 
     const generateResponse = await fetch(api_url + '/complete', {
         method: "POST",
         signal: controller.signal,
         body: JSON.stringify({
-            prompt: "\n\nHuman: " + requestPrompt,
+            prompt: requestPrompt,
             model: request.body.model,
             max_tokens_to_sample: request.body.max_tokens,
             stop_sequences: ["\n\nHuman:", "\n\nSystem:", "\n\nAssistant:"],
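
Note: the "\n\nHuman: " prefix formerly concatenated at this call site now comes from convertClaudePrompt itself, so the prompt string is passed through unchanged. The payload sent to the Anthropic /complete endpoint ends up shaped like this (field values are placeholders, not from this diff):

    JSON.stringify({
        prompt: "\n\nHuman: ...\n\nAssistant: ",  // convertClaudePrompt(messages, true, true)
        model: "claude-v1",                       // request.body.model
        max_tokens_to_sample: 300,                // request.body.max_tokens
        stop_sequences: ["\n\nHuman:", "\n\nSystem:", "\n\nAssistant:"],
    })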
@@ -3166,15 +3211,20 @@ app.post("/generate_openai", jsonParser, function (request, response_generate_op
 
 app.post("/tokenize_openai", jsonParser, function (request, response_tokenize_openai = response) {
     if (!request.body) return response_tokenize_openai.sendStatus(400);
 
+    let num_tokens = 0;
     const model = getTokenizerModel(String(request.query.model || ''));
+
+    if (model == 'claude') {
+        num_tokens = countClaudeTokens(claude_tokenizer, request.body);
+        return response_tokenize_openai.send({ "token_count": num_tokens });
+    }
+
     const tokensPerName = model.includes('gpt-4') ? 1 : -1;
     const tokensPerMessage = model.includes('gpt-4') ? 3 : 4;
     const tokensPadding = 3;
 
     const tokenizer = getTiktokenTokenizer(model);
 
-    let num_tokens = 0;
     for (const msg of request.body) {
        num_tokens += tokensPerMessage;
        for (const [key, value] of Object.entries(msg)) {
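
Note: /tokenize_openai now short-circuits for Claude models before any tiktoken setup. A hedged client-side example (the host and port are assumptions, not from this diff; the endpoint path, model query parameter, and response shape are taken from it):

    const res = await fetch('http://127.0.0.1:8000/tokenize_openai?model=claude-v1', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify([{ role: 'user', content: 'Hello!' }]),
    });
    const { token_count } = await res.json(); // counted via countClaudeTokens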
@@ -3282,10 +3332,11 @@ const setupTasks = async function () {
     // Colab users could run the embedded tool
     if (!is_colab) await convertWebp();
 
-    [spp_llama, spp_nerd, spp_nerd_v2] = await Promise.all([
+    [spp_llama, spp_nerd, spp_nerd_v2, claude_tokenizer] = await Promise.all([
         loadSentencepieceTokenizer('src/sentencepiece/tokenizer.model'),
         loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
         loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
+        loadClaudeTokenizer('src/claude.json'),
     ]);
 
     console.log('Launching...');