Add Claude tokenizer

This commit is contained in:
Cohee
2023-06-26 13:36:56 +03:00
parent 7354003db1
commit 68f967ea78
7 changed files with 68 additions and 106 deletions

6
package-lock.json generated
View File

@@ -10,6 +10,7 @@
"license": "AGPL-3.0", "license": "AGPL-3.0",
"dependencies": { "dependencies": {
"@dqbd/tiktoken": "^1.0.2", "@dqbd/tiktoken": "^1.0.2",
"@mlc-ai/web-tokenizers": "^0.1.0",
"axios": "^1.4.0", "axios": "^1.4.0",
"command-exists": "^1.2.9", "command-exists": "^1.2.9",
"compression": "^1", "compression": "^1",
@@ -561,6 +562,11 @@
"integrity": "sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw==", "integrity": "sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw==",
"dev": true "dev": true
}, },
"node_modules/@mlc-ai/web-tokenizers": {
"version": "0.1.0",
"resolved": "https://registry.npmjs.org/@mlc-ai/web-tokenizers/-/web-tokenizers-0.1.0.tgz",
"integrity": "sha512-whiQ+40ohtAFoFOGcje1Io7BMr434Wh3hM3nBCWlJMpXxL5Rlig/AH9wjyUPsytKwWTEe7RoYPyXSbFw5Vs6Tw=="
},
"node_modules/@nodelib/fs.scandir": { "node_modules/@nodelib/fs.scandir": {
"version": "2.1.5", "version": "2.1.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",

View File

@@ -1,6 +1,7 @@
{ {
"dependencies": { "dependencies": {
"@dqbd/tiktoken": "^1.0.2", "@dqbd/tiktoken": "^1.0.2",
"@mlc-ai/web-tokenizers": "^0.1.0",
"axios": "^1.4.0", "axios": "^1.4.0",
"command-exists": "^1.2.9", "command-exists": "^1.2.9",
"compression": "^1", "compression": "^1",

View File

@@ -1,73 +0,0 @@
body {
margin: 0;
padding: 0;
width: 100%;
background-color: rgb(36, 37, 37);
background-repeat: no-repeat;
background-attachment: fixed;
background-size: cover;
font-family: "Noto Sans", "Noto Color Emoji", sans-serif;
font-size: 16px;
/*1rem*/
color: #999;
box-sizing: border-box;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
/*z-index:0;*/
}
#main {
padding-top: 20px;
/*z-index:1;*/
}
#content {
margin: 0 auto;
max-width: 700px;
border: 1px solid #333;
padding: 20px;
border-radius: 20px;
background-color: rgba(0, 0, 0, 0.5);
line-height: 1.5rem;
box-shadow: 0 0 5px black;
/*z-index: 2;*/
}
code {
border: 1px solid #999;
background-color: rgba(0, 0, 0, 0.5);
padding: 5px;
border-radius: 5px;
display: block;
white-space: pre-line;
}
a {
color: orange;
text-decoration: none;
border-bottom: 1px dotted orange;
}
h2,
h3 {
color: #ccc;
}
hr {
border: 1px solid #999;
}
table {
width: 100%;
}
table,
th,
td {
border: 1px solid;
border-collapse: collapse;
}
table img {
max-width: 200px;
}

View File

@@ -1,23 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>SillyTavern Documentation</title>
<link rel="stylesheet" href="/css/notes.css">
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link href="/webfonts/NotoSans/stylesheet.css" rel="stylesheet">
</head>
<body>
<div id="main">
<div id="content">
<h2>You weren't supposed to be able to get here, you know.</h1>
<h3>All help materials has been moved here:</h3>
<h3><a href="https://docs.sillytavern.app/">SillyTavern Documentation</a></h3>
</div>
</div>
</body>
</html>

View File

@@ -927,16 +927,15 @@ function getTokenizerModel() {
return turboTokenizer; return turboTokenizer;
} }
else if (oai_settings.windowai_model.includes('claude')) { else if (oai_settings.windowai_model.includes('claude')) {
return turboTokenizer; return 'claude';
} }
else if (oai_settings.windowai_model.includes('GPT-NeoXT')) { else if (oai_settings.windowai_model.includes('GPT-NeoXT')) {
return 'gpt2'; return 'gpt2';
} }
} }
// We don't have a Claude tokenizer for JS yet. Turbo 3.5 should be able to handle this.
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) { if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
return turboTokenizer; return 'claude';
} }
// Default to Turbo 3.5 // Default to Turbo 3.5

View File

@@ -128,10 +128,13 @@ let response_getstatus;
const delay = ms => new Promise(resolve => setTimeout(resolve, ms)) const delay = ms => new Promise(resolve => setTimeout(resolve, ms))
const { SentencePieceProcessor, cleanText } = require("sentencepiece-js"); const { SentencePieceProcessor, cleanText } = require("sentencepiece-js");
const { Tokenizer } = require('@mlc-ai/web-tokenizers');
const CHARS_PER_TOKEN = 3.35;
let spp_llama; let spp_llama;
let spp_nerd; let spp_nerd;
let spp_nerd_v2; let spp_nerd_v2;
let claude_tokenizer;
async function loadSentencepieceTokenizer(modelPath) { async function loadSentencepieceTokenizer(modelPath) {
try { try {
@@ -147,7 +150,7 @@ async function loadSentencepieceTokenizer(modelPath) {
async function countSentencepieceTokens(spp, text) { async function countSentencepieceTokens(spp, text) {
// Fallback to strlen estimation // Fallback to strlen estimation
if (!spp) { if (!spp) {
return Math.ceil(text.length / 3.35); return Math.ceil(text.length / CHARS_PER_TOKEN);
} }
let cleaned = cleanText(text); let cleaned = cleanText(text);
@@ -156,9 +159,36 @@ async function countSentencepieceTokens(spp, text) {
return ids.length; return ids.length;
} }
async function loadClaudeTokenizer(modelPath) {
try {
const arrayBuffer = fs.readFileSync(modelPath).buffer;
const instance = await Tokenizer.fromJSON(arrayBuffer);
return instance;
} catch (error) {
console.error("Claude tokenizer failed to load: " + modelPath, error);
return null;
}
}
function countClaudeTokens(tokenizer, messages) {
const convertedPrompt = convertClaudePrompt(messages, false, false);
// Fallback to strlen estimation
if (!tokenizer) {
return Math.ceil(convertedPrompt.length / CHARS_PER_TOKEN);
}
const count = tokenizer.encode(convertedPrompt).length;
return count;
}
const tokenizersCache = {}; const tokenizersCache = {};
function getTokenizerModel(requestModel) { function getTokenizerModel(requestModel) {
if (requestModel.includes('claude')) {
return 'claude';
}
if (requestModel.includes('gpt-4-32k')) { if (requestModel.includes('gpt-4-32k')) {
return 'gpt-4-32k'; return 'gpt-4-32k';
} }
@@ -2870,6 +2900,12 @@ app.post("/openai_bias", jsonParser, async function (request, response) {
let result = {}; let result = {};
const model = getTokenizerModel(String(request.query.model || '')); const model = getTokenizerModel(String(request.query.model || ''));
// no bias for claude
if (model == 'claude') {
return response.send(result);
}
const tokenizer = getTiktokenTokenizer(model); const tokenizer = getTiktokenTokenizer(model);
for (const entry of request.body) { for (const entry of request.body) {
@@ -2942,7 +2978,7 @@ app.post("/deletepreset_openai", jsonParser, function (request, response) {
}); });
// Prompt Conversion script taken from RisuAI by @kwaroran (GPLv3). // Prompt Conversion script taken from RisuAI by @kwaroran (GPLv3).
function convertClaudePrompt(messages) { function convertClaudePrompt(messages, addHumanPrefix, addAssistantPostfix) {
// Claude doesn't support message names, so we'll just add them to the message content. // Claude doesn't support message names, so we'll just add them to the message content.
for (const message of messages) { for (const message of messages) {
if (message.name && message.role !== "system") { if (message.name && message.role !== "system") {
@@ -2972,7 +3008,16 @@ function convertClaudePrompt(messages) {
break break
} }
return prefix + v.content; return prefix + v.content;
}).join('') + '\n\nAssistant: '; }).join('');
if (addHumanPrefix) {
requestPrompt = "\n\nHuman: " + requestPrompt;
}
if (addAssistantPostfix) {
requestPrompt = requestPrompt + '\n\nAssistant: ';
}
return requestPrompt; return requestPrompt;
} }
@@ -2993,14 +3038,14 @@ async function sendClaudeRequest(request, response) {
controller.abort(); controller.abort();
}); });
const requestPrompt = convertClaudePrompt(request.body.messages); const requestPrompt = convertClaudePrompt(request.body.messages, true, true);
console.log('Claude request:', requestPrompt); console.log('Claude request:', requestPrompt);
const generateResponse = await fetch(api_url + '/complete', { const generateResponse = await fetch(api_url + '/complete', {
method: "POST", method: "POST",
signal: controller.signal, signal: controller.signal,
body: JSON.stringify({ body: JSON.stringify({
prompt: "\n\nHuman: " + requestPrompt, prompt: requestPrompt,
model: request.body.model, model: request.body.model,
max_tokens_to_sample: request.body.max_tokens, max_tokens_to_sample: request.body.max_tokens,
stop_sequences: ["\n\nHuman:", "\n\nSystem:", "\n\nAssistant:"], stop_sequences: ["\n\nHuman:", "\n\nSystem:", "\n\nAssistant:"],
@@ -3166,15 +3211,20 @@ app.post("/generate_openai", jsonParser, function (request, response_generate_op
app.post("/tokenize_openai", jsonParser, function (request, response_tokenize_openai = response) { app.post("/tokenize_openai", jsonParser, function (request, response_tokenize_openai = response) {
if (!request.body) return response_tokenize_openai.sendStatus(400); if (!request.body) return response_tokenize_openai.sendStatus(400);
let num_tokens = 0;
const model = getTokenizerModel(String(request.query.model || '')); const model = getTokenizerModel(String(request.query.model || ''));
if (model == 'claude') {
num_tokens = countClaudeTokens(claude_tokenizer, request.body);
return response_tokenize_openai.send({ "token_count": num_tokens });
}
const tokensPerName = model.includes('gpt-4') ? 1 : -1; const tokensPerName = model.includes('gpt-4') ? 1 : -1;
const tokensPerMessage = model.includes('gpt-4') ? 3 : 4; const tokensPerMessage = model.includes('gpt-4') ? 3 : 4;
const tokensPadding = 3; const tokensPadding = 3;
const tokenizer = getTiktokenTokenizer(model); const tokenizer = getTiktokenTokenizer(model);
let num_tokens = 0;
for (const msg of request.body) { for (const msg of request.body) {
num_tokens += tokensPerMessage; num_tokens += tokensPerMessage;
for (const [key, value] of Object.entries(msg)) { for (const [key, value] of Object.entries(msg)) {
@@ -3282,10 +3332,11 @@ const setupTasks = async function () {
// Colab users could run the embedded tool // Colab users could run the embedded tool
if (!is_colab) await convertWebp(); if (!is_colab) await convertWebp();
[spp_llama, spp_nerd, spp_nerd_v2] = await Promise.all([ [spp_llama, spp_nerd, spp_nerd_v2, claude_tokenizer] = await Promise.all([
loadSentencepieceTokenizer('src/sentencepiece/tokenizer.model'), loadSentencepieceTokenizer('src/sentencepiece/tokenizer.model'),
loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'), loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'), loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
loadClaudeTokenizer('src/claude.json'),
]); ]);
console.log('Launching...'); console.log('Launching...');

1
src/claude.json Normal file

File diff suppressed because one or more lines are too long