Add Mistral tokenizer

This commit is contained in:
Cohee
2023-11-06 01:26:13 +02:00
parent c3479b23d9
commit f248367ca3
5 changed files with 36 additions and 5 deletions

View File

@ -16,6 +16,7 @@ export const tokenizers = {
NERD: 4,
NERD2: 5,
API: 6,
MISTRAL: 7,
BEST_MATCH: 99,
};
@ -105,6 +106,8 @@ function callTokenizer(type, str, padding) {
return countTokensRemote('/api/tokenize/nerdstash', str, padding);
case tokenizers.NERD2:
return countTokensRemote('/api/tokenize/nerdstash_v2', str, padding);
case tokenizers.MISTRAL:
return countTokensRemote('/api/tokenize/mistral', str, padding);
case tokenizers.API:
return countTokensRemote('/tokenize_via_api', str, padding);
default:
@ -185,6 +188,7 @@ export function getTokenizerModel() {
const gpt2Tokenizer = 'gpt2';
const claudeTokenizer = 'claude';
const llamaTokenizer = 'llama';
const mistralTokenizer = 'mistral';
// Assuming no one would use it for different models.. right?
if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) {
@ -217,6 +221,9 @@ export function getTokenizerModel() {
if (model?.architecture?.tokenizer === 'Llama2') {
return llamaTokenizer;
}
else if (model?.architecture?.tokenizer === 'Mistral') {
return mistralTokenizer;
}
else if (oai_settings.openrouter_model.includes('gpt-4')) {
return gpt4Tokenizer;
}
@ -420,6 +427,8 @@ export function getTextTokens(tokenizerType, str) {
return getTextTokensRemote('/api/tokenize/nerdstash', str);
case tokenizers.NERD2:
return getTextTokensRemote('/api/tokenize/nerdstash_v2', str);
case tokenizers.MISTRAL:
return getTextTokensRemote('/api/tokenize/mistral', str);
case tokenizers.OPENAI:
const model = getTokenizerModel();
return getTextTokensRemote('/api/tokenize/openai-encode', str, model);
@ -444,6 +453,8 @@ export function decodeTextTokens(tokenizerType, ids) {
return decodeTextTokensRemote('/api/decode/nerdstash', ids);
case tokenizers.NERD2:
return decodeTextTokensRemote('/api/decode/nerdstash_v2', ids);
case tokenizers.MISTRAL:
return decodeTextTokensRemote('/api/decode/mistral', ids);
default:
console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType);
return '';