mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-04-12 10:01:08 +02:00
Add Mistral tokenizer
This commit is contained in:
parent
c3479b23d9
commit
f248367ca3
@ -2384,9 +2384,10 @@
|
|||||||
<option value="0">None / Estimated</option>
|
<option value="0">None / Estimated</option>
|
||||||
<option value="1">GPT-2</option>
|
<option value="1">GPT-2</option>
|
||||||
<!-- Option #2 was a legacy GPT-2/3 tokenizer -->
|
<!-- Option #2 was a legacy GPT-2/3 tokenizer -->
|
||||||
<option value="3">Sentencepiece (LLaMA)</option>
|
<option value="3">LLaMA</option>
|
||||||
<option value="4">NerdStash (NovelAI Clio)</option>
|
<option value="4">NerdStash (NovelAI Clio)</option>
|
||||||
<option value="5">NerdStash v2 (NovelAI Kayra)</option>
|
<option value="5">NerdStash v2 (NovelAI Kayra)</option>
|
||||||
|
<option value="7">Mistral</option>
|
||||||
<option value="6">API (WebUI / koboldcpp)</option>
|
<option value="6">API (WebUI / koboldcpp)</option>
|
||||||
</select>
|
</select>
|
||||||
</div>
|
</div>
|
||||||
@ -5044,4 +5045,4 @@
|
|||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
|
|
||||||
</html>
|
</html>
|
||||||
|
@ -16,6 +16,7 @@ export const tokenizers = {
|
|||||||
NERD: 4,
|
NERD: 4,
|
||||||
NERD2: 5,
|
NERD2: 5,
|
||||||
API: 6,
|
API: 6,
|
||||||
|
MISTRAL: 7,
|
||||||
BEST_MATCH: 99,
|
BEST_MATCH: 99,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -105,6 +106,8 @@ function callTokenizer(type, str, padding) {
|
|||||||
return countTokensRemote('/api/tokenize/nerdstash', str, padding);
|
return countTokensRemote('/api/tokenize/nerdstash', str, padding);
|
||||||
case tokenizers.NERD2:
|
case tokenizers.NERD2:
|
||||||
return countTokensRemote('/api/tokenize/nerdstash_v2', str, padding);
|
return countTokensRemote('/api/tokenize/nerdstash_v2', str, padding);
|
||||||
|
case tokenizers.MISTRAL:
|
||||||
|
return countTokensRemote('/api/tokenize/mistral', str, padding);
|
||||||
case tokenizers.API:
|
case tokenizers.API:
|
||||||
return countTokensRemote('/tokenize_via_api', str, padding);
|
return countTokensRemote('/tokenize_via_api', str, padding);
|
||||||
default:
|
default:
|
||||||
@ -185,6 +188,7 @@ export function getTokenizerModel() {
|
|||||||
const gpt2Tokenizer = 'gpt2';
|
const gpt2Tokenizer = 'gpt2';
|
||||||
const claudeTokenizer = 'claude';
|
const claudeTokenizer = 'claude';
|
||||||
const llamaTokenizer = 'llama';
|
const llamaTokenizer = 'llama';
|
||||||
|
const mistralTokenizer = 'mistral';
|
||||||
|
|
||||||
// Assuming no one would use it for different models.. right?
|
// Assuming no one would use it for different models.. right?
|
||||||
if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) {
|
if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) {
|
||||||
@ -217,6 +221,9 @@ export function getTokenizerModel() {
|
|||||||
if (model?.architecture?.tokenizer === 'Llama2') {
|
if (model?.architecture?.tokenizer === 'Llama2') {
|
||||||
return llamaTokenizer;
|
return llamaTokenizer;
|
||||||
}
|
}
|
||||||
|
else if (model?.architecture?.tokenizer === 'Mistral') {
|
||||||
|
return mistralTokenizer;
|
||||||
|
}
|
||||||
else if (oai_settings.openrouter_model.includes('gpt-4')) {
|
else if (oai_settings.openrouter_model.includes('gpt-4')) {
|
||||||
return gpt4Tokenizer;
|
return gpt4Tokenizer;
|
||||||
}
|
}
|
||||||
@ -420,6 +427,8 @@ export function getTextTokens(tokenizerType, str) {
|
|||||||
return getTextTokensRemote('/api/tokenize/nerdstash', str);
|
return getTextTokensRemote('/api/tokenize/nerdstash', str);
|
||||||
case tokenizers.NERD2:
|
case tokenizers.NERD2:
|
||||||
return getTextTokensRemote('/api/tokenize/nerdstash_v2', str);
|
return getTextTokensRemote('/api/tokenize/nerdstash_v2', str);
|
||||||
|
case tokenizers.MISTRAL:
|
||||||
|
return getTextTokensRemote('/api/tokenize/mistral', str);
|
||||||
case tokenizers.OPENAI:
|
case tokenizers.OPENAI:
|
||||||
const model = getTokenizerModel();
|
const model = getTokenizerModel();
|
||||||
return getTextTokensRemote('/api/tokenize/openai-encode', str, model);
|
return getTextTokensRemote('/api/tokenize/openai-encode', str, model);
|
||||||
@ -444,6 +453,8 @@ export function decodeTextTokens(tokenizerType, ids) {
|
|||||||
return decodeTextTokensRemote('/api/decode/nerdstash', ids);
|
return decodeTextTokensRemote('/api/decode/nerdstash', ids);
|
||||||
case tokenizers.NERD2:
|
case tokenizers.NERD2:
|
||||||
return decodeTextTokensRemote('/api/decode/nerdstash_v2', ids);
|
return decodeTextTokensRemote('/api/decode/nerdstash_v2', ids);
|
||||||
|
case tokenizers.MISTRAL:
|
||||||
|
return decodeTextTokensRemote('/api/decode/mistral', ids);
|
||||||
default:
|
default:
|
||||||
console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType);
|
console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType);
|
||||||
return '';
|
return '';
|
||||||
|
BIN
src/sentencepiece/mistral.model
Normal file
BIN
src/sentencepiece/mistral.model
Normal file
Binary file not shown.
@ -46,6 +46,7 @@ const CHARS_PER_TOKEN = 3.35;
|
|||||||
let spp_llama;
|
let spp_llama;
|
||||||
let spp_nerd;
|
let spp_nerd;
|
||||||
let spp_nerd_v2;
|
let spp_nerd_v2;
|
||||||
|
let spp_mistral;
|
||||||
let claude_tokenizer;
|
let claude_tokenizer;
|
||||||
|
|
||||||
async function loadSentencepieceTokenizer(modelPath) {
|
async function loadSentencepieceTokenizer(modelPath) {
|
||||||
@ -91,6 +92,10 @@ function getTokenizerModel(requestModel) {
|
|||||||
return 'llama';
|
return 'llama';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (requestModel.includes('mistral')) {
|
||||||
|
return 'mistral';
|
||||||
|
}
|
||||||
|
|
||||||
if (requestModel.includes('gpt-4-32k')) {
|
if (requestModel.includes('gpt-4-32k')) {
|
||||||
return 'gpt-4-32k';
|
return 'gpt-4-32k';
|
||||||
}
|
}
|
||||||
@ -247,10 +252,11 @@ function createTiktokenDecodingHandler(modelId) {
|
|||||||
* @returns {Promise<void>} Promise that resolves when the tokenizers are loaded
|
* @returns {Promise<void>} Promise that resolves when the tokenizers are loaded
|
||||||
*/
|
*/
|
||||||
async function loadTokenizers() {
|
async function loadTokenizers() {
|
||||||
[spp_llama, spp_nerd, spp_nerd_v2, claude_tokenizer] = await Promise.all([
|
[spp_llama, spp_nerd, spp_nerd_v2, spp_mistral, claude_tokenizer] = await Promise.all([
|
||||||
loadSentencepieceTokenizer('src/sentencepiece/tokenizer.model'),
|
loadSentencepieceTokenizer('src/sentencepiece/llama.model'),
|
||||||
loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
|
loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
|
||||||
loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
|
loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
|
||||||
|
loadSentencepieceTokenizer('src/sentencepiece/mistral.model'),
|
||||||
loadClaudeTokenizer('src/claude.json'),
|
loadClaudeTokenizer('src/claude.json'),
|
||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
@ -286,10 +292,12 @@ function registerEndpoints(app, jsonParser) {
|
|||||||
app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(() => spp_llama));
|
app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(() => spp_llama));
|
||||||
app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd));
|
app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd));
|
||||||
app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd_v2));
|
app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd_v2));
|
||||||
|
app.post("/api/tokenize/mistral", jsonParser, createSentencepieceEncodingHandler(() => spp_mistral));
|
||||||
app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
|
app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
|
||||||
app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(() => spp_llama));
|
app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(() => spp_llama));
|
||||||
app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd));
|
app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd));
|
||||||
app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2));
|
app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2));
|
||||||
|
app.post("/api/decode/mistral", jsonParser, createSentencepieceDecodingHandler(() => spp_mistral));
|
||||||
app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
|
app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
|
||||||
|
|
||||||
app.post("/api/tokenize/openai-encode", jsonParser, async function (req, res) {
|
app.post("/api/tokenize/openai-encode", jsonParser, async function (req, res) {
|
||||||
@ -301,6 +309,11 @@ function registerEndpoints(app, jsonParser) {
|
|||||||
return handler(req, res);
|
return handler(req, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (queryModel.includes('mistral')) {
|
||||||
|
const handler = createSentencepieceEncodingHandler(() => spp_mistral);
|
||||||
|
return handler(req, res);
|
||||||
|
}
|
||||||
|
|
||||||
if (queryModel.includes('claude')) {
|
if (queryModel.includes('claude')) {
|
||||||
const text = req.body.text || '';
|
const text = req.body.text || '';
|
||||||
const tokens = Object.values(claude_tokenizer.encode(text));
|
const tokens = Object.values(claude_tokenizer.encode(text));
|
||||||
@ -332,11 +345,17 @@ function registerEndpoints(app, jsonParser) {
|
|||||||
if (model == 'llama') {
|
if (model == 'llama') {
|
||||||
const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
|
const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
|
||||||
const llamaResult = await countSentencepieceTokens(spp_llama, jsonBody);
|
const llamaResult = await countSentencepieceTokens(spp_llama, jsonBody);
|
||||||
// console.log('jsonBody', jsonBody, 'llamaResult', llamaResult);
|
|
||||||
num_tokens = llamaResult.count;
|
num_tokens = llamaResult.count;
|
||||||
return res.send({ "token_count": num_tokens });
|
return res.send({ "token_count": num_tokens });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (model == 'mistral') {
|
||||||
|
const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
|
||||||
|
const mistralResult = await countSentencepieceTokens(spp_mistral, jsonBody);
|
||||||
|
num_tokens = mistralResult.count;
|
||||||
|
return res.send({ "token_count": num_tokens });
|
||||||
|
}
|
||||||
|
|
||||||
const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
|
const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
|
||||||
const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
|
const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
|
||||||
const tokensPadding = 3;
|
const tokensPadding = 3;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user