Add llama 3 tokenizer

Cohee 2024-05-03 23:59:39 +03:00
parent 7bc87b6e28
commit 7bfd666321
8 changed files with 143 additions and 21 deletions

View File

@@ -3144,11 +3144,13 @@
<option value="0">None / Estimated</option>
<option value="1">GPT-2</option>
<!-- Option #2 was a legacy GPT-2/3 tokenizer -->
<option value="3">LLaMA</option>
<option value="3">Llama 1/2</option>
<option value="12">Llama 3</option>
<option value="4">NerdStash (NovelAI Clio)</option>
<option value="5">NerdStash v2 (NovelAI Kayra)</option>
<option value="7">Mistral</option>
<option value="8">Yi</option>
<option value="11">Claude 1/2</option>
<option value="6">API (WebUI / koboldcpp)</option>
</select>
</div>
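
Note: the numeric option values line up with the ids in the client-side tokenizers enum extended below (value 12 is the new LLAMA3 entry). A minimal sketch of how the selection could be read, with a hypothetical element id and the value-to-enum wiring assumed rather than shown in this diff:

const select = document.getElementById('tokenizer_select'); // element id is illustrative
const tokenizerId = Number(select.value); // '12' -> tokenizers.LLAMA3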

View File

@@ -440,6 +440,10 @@ export function getCurrentOpenRouterModelTokenizer() {
switch (model?.architecture?.tokenizer) {
case 'Llama2':
return tokenizers.LLAMA;
case 'Llama3':
return tokenizers.LLAMA3;
case 'Yi':
return tokenizers.YI;
case 'Mistral':
return tokenizers.MISTRAL;
default:
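
A hedged sketch of the OpenRouter model metadata this switch is assumed to inspect; the model id is illustrative, while the architecture.tokenizer field matches the optional chaining above:

const exampleModel = { id: 'meta-llama/llama-3-70b-instruct', architecture: { tokenizer: 'Llama3' } };
// With 'Llama3' metadata, the switch above now resolves to tokenizers.LLAMA3 instead of the default.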

View File

@@ -932,7 +932,7 @@ function toIntArray(string) {
return string.split(',').map(x => parseInt(x)).filter(x => !isNaN(x));
}
function getModel() {
export function getTextGenModel() {
switch (settings.type) {
case OOBA:
if (settings.custom_model) {
@@ -974,7 +974,7 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
const canMultiSwipe = !isContinue && !isImpersonate && type !== 'quiet';
let params = {
'prompt': finalPrompt,
'model': getModel(),
'model': getTextGenModel(),
'max_new_tokens': maxTokens,
'max_tokens': maxTokens,
'logprobs': power_user.request_token_probabilities ? 10 : undefined,

View File

@@ -4,7 +4,7 @@ import { chat_completion_sources, model_list, oai_settings } from './openai.js';
import { groups, selected_group } from './group-chats.js';
import { getStringHash } from './utils.js';
import { kai_flags } from './kai-settings.js';
import { textgen_types, textgenerationwebui_settings as textgen_settings, getTextGenServer } from './textgen-settings.js';
import { textgen_types, textgenerationwebui_settings as textgen_settings, getTextGenServer, getTextGenModel } from './textgen-settings.js';
import { getCurrentDreamGenModelTokenizer, getCurrentOpenRouterModelTokenizer, openRouterModels } from './textgen-models.js';
const { OOBA, TABBY, KOBOLDCPP, APHRODITE, LLAMACPP, OPENROUTER, DREAMGEN } = textgen_types;
@@ -24,6 +24,8 @@ export const tokenizers = {
YI: 8,
API_TEXTGENERATIONWEBUI: 9,
API_KOBOLD: 10,
CLAUDE: 11,
LLAMA3: 12,
BEST_MATCH: 99,
};
@@ -31,6 +33,7 @@ export const SENTENCEPIECE_TOKENIZERS = [
tokenizers.LLAMA,
tokenizers.MISTRAL,
tokenizers.YI,
tokenizers.LLAMA3,
// uncomment when NovelAI releases Kayra and Clio weights, lol
//tokenizers.NERD,
//tokenizers.NERD2,
@@ -78,6 +81,16 @@ const TOKENIZER_URLS = {
decode: '/api/tokenizers/yi/decode',
count: '/api/tokenizers/yi/encode',
},
[tokenizers.CLAUDE]: {
encode: '/api/tokenizers/claude/encode',
decode: '/api/tokenizers/claude/decode',
count: '/api/tokenizers/claude/encode',
},
[tokenizers.LLAMA3]: {
encode: '/api/tokenizers/llama3/encode',
decode: '/api/tokenizers/llama3/decode',
count: '/api/tokenizers/llama3/encode',
},
[tokenizers.API_TEXTGENERATIONWEBUI]: {
encode: '/api/tokenizers/remote/textgenerationwebui/encode',
count: '/api/tokenizers/remote/textgenerationwebui/encode',
@@ -211,6 +224,16 @@ export function getTokenizerBestMatch(forApi) {
}
}
if (forApi === 'textgenerationwebui') {
const model = String(getTextGenModel() || online_status).toLowerCase();
if (model.includes('llama3') || model.includes('llama-3')) {
return tokenizers.LLAMA3;
}
if (model.includes('mistral') || model.includes('mixtral')) {
return tokenizers.MISTRAL;
}
}
return tokenizers.LLAMA;
}
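
A minimal sketch of the substring fallback added above, with illustrative model names; any name that matches neither branch keeps the existing LLAMA default:

const name = 'Meta-Llama-3-8B-Instruct-Q5_K_M.gguf'; // illustrative
const lowered = name.toLowerCase();
const best = (lowered.includes('llama3') || lowered.includes('llama-3')) ? tokenizers.LLAMA3
    : (lowered.includes('mistral') || lowered.includes('mixtral')) ? tokenizers.MISTRAL
    : tokenizers.LLAMA;
// best === tokenizers.LLAMA3 here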
@@ -421,6 +444,7 @@ export function getTokenizerModel() {
const gpt2Tokenizer = 'gpt2';
const claudeTokenizer = 'claude';
const llamaTokenizer = 'llama';
const llama3Tokenizer = 'llama3';
const mistralTokenizer = 'mistral';
const yiTokenizer = 'yi';
@@ -458,6 +482,9 @@ export function getTokenizerModel() {
if (model?.architecture?.tokenizer === 'Llama2') {
return llamaTokenizer;
}
else if (model?.architecture?.tokenizer === 'Llama3') {
return llama3Tokenizer;
}
else if (model?.architecture?.tokenizer === 'Mistral') {
return mistralTokenizer;
}
@@ -498,10 +525,13 @@ export function getTokenizerModel() {
}
if (oai_settings.chat_completion_source === chat_completion_sources.PERPLEXITY) {
if (oai_settings.perplexity_model.includes('llama-3') || oai_settings.perplexity_model.includes('llama3')) {
return llama3Tokenizer;
}
if (oai_settings.perplexity_model.includes('llama')) {
return llamaTokenizer;
}
if (oai_settings.perplexity_model.includes('mistral')) {
if (oai_settings.perplexity_model.includes('mistral') || oai_settings.perplexity_model.includes('mixtral')) {
return mistralTokenizer;
}
}
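
Ordering matters in the Perplexity branch: an id such as 'llama-3-70b-instruct' (illustrative) also contains the plain substring 'llama', so the llama-3 check has to run first to return llama3Tokenizer rather than llamaTokenizer. For example:

'llama-3-70b-instruct'.includes('llama-3'); // true, handled first -> llama3Tokenizer
'llama-3-70b-instruct'.includes('llama');   // also true, but only reached for Llama 1/2 ids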

View File

@@ -730,7 +730,11 @@ router.post('/bias', jsonParser, async function (request, response) {
if (sentencepieceTokenizers.includes(model)) {
const tokenizer = getSentencepiceTokenizer(model);
const instance = await tokenizer?.get();
encodeFunction = (text) => new Uint32Array(instance?.encodeIds(text));
if (!instance) {
console.warn('Tokenizer not initialized:', model);
return response.send({});
}
encodeFunction = (text) => new Uint32Array(instance.encodeIds(text));
} else {
const tokenizer = getTiktokenTokenizer(model);
encodeFunction = (tokenizer.encode.bind(tokenizer));

View File

@@ -142,6 +142,7 @@ const spp_nerd_v2 = new SentencePieceTokenizer('src/tokenizers/nerdstash_v2.mode
const spp_mistral = new SentencePieceTokenizer('src/tokenizers/mistral.model');
const spp_yi = new SentencePieceTokenizer('src/tokenizers/yi.model');
const claude_tokenizer = new WebTokenizer('src/tokenizers/claude.json');
const llama3_tokenizer = new WebTokenizer('src/tokenizers/llama3.json');
const sentencepieceTokenizers = [
'llama',
@@ -285,6 +286,10 @@ function getTokenizerModel(requestModel) {
return 'claude';
}
if (requestModel.includes('llama3') || requestModel.includes('llama-3')) {
return 'llama3';
}
if (requestModel.includes('llama')) {
return 'llama';
}
@@ -313,12 +318,12 @@ function getTiktokenTokenizer(model) {
}
/**
* Counts the tokens for the given messages using the Claude tokenizer.
* Counts the tokens for the given messages using the WebTokenizer and Claude prompt conversion.
* @param {Tokenizer} tokenizer Web tokenizer
* @param {object[]} messages Array of messages
* @returns {number} Number of tokens
*/
function countClaudeTokens(tokenizer, messages) {
function countWebTokenizerTokens(tokenizer, messages) {
// Should be fine if we use the old conversion method instead of the messages API one i think?
const convertedPrompt = convertClaudePrompt(messages, false, '', false, false, '', false);
@@ -449,6 +454,67 @@ function createTiktokenDecodingHandler(modelId) {
};
}
/**
* Creates an API handler for encoding WebTokenizer tokens.
* @param {WebTokenizer} tokenizer WebTokenizer instance
* @returns {TokenizationHandler} Handler function
*/
function createWebTokenizerEncodingHandler(tokenizer) {
/**
* Request handler for encoding WebTokenizer tokens.
* @param {import('express').Request} request
* @param {import('express').Response} response
*/
return async function (request, response) {
try {
if (!request.body) {
return response.sendStatus(400);
}
const text = request.body.text || '';
const instance = await tokenizer?.get();
if (!instance) throw new Error('Failed to load the Web tokenizer');
const tokens = Array.from(instance.encode(text));
const chunks = getWebTokenizersChunks(instance, tokens);
return response.send({ ids: tokens, count: tokens.length, chunks });
} catch (error) {
console.log(error);
return response.send({ ids: [], count: 0, chunks: [] });
}
};
}
/**
* Creates an API handler for decoding WebTokenizer tokens.
* @param {WebTokenizer} tokenizer WebTokenizer instance
* @returns {TokenizationHandler} Handler function
*/
function createWebTokenizerDecodingHandler(tokenizer) {
/**
* Request handler for decoding WebTokenizer tokens.
* @param {import('express').Request} request
* @param {import('express').Response} response
* @returns {Promise<any>}
*/
return async function (request, response) {
try {
if (!request.body) {
return response.sendStatus(400);
}
const ids = request.body.ids || [];
const instance = await tokenizer?.get();
if (!instance) throw new Error('Failed to load the Web tokenizer');
const chunks = getWebTokenizersChunks(instance, ids);
const text = instance.decode(new Int32Array(ids));
return response.send({ text, chunks });
} catch (error) {
console.log(error);
return response.send({ text: '', chunks: [] });
}
};
}
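
For reference, a hedged example of calling the Llama 3 encode route registered below: the request and response shapes follow the encoding handler above, while any auth or CSRF headers the server may additionally require are omitted.

const res = await fetch('/api/tokenizers/llama3/encode', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ text: 'Hello, world!' }),
});
const { ids, count, chunks } = await res.json(); // shape produced by the encoding handler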
const router = express.Router();
router.post('/ai21/count', jsonParser, async function (req, res) {
@@ -501,17 +567,26 @@ router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandl
router.post('/mistral/encode', jsonParser, createSentencepieceEncodingHandler(spp_mistral));
router.post('/yi/encode', jsonParser, createSentencepieceEncodingHandler(spp_yi));
router.post('/gpt2/encode', jsonParser, createTiktokenEncodingHandler('gpt2'));
router.post('/claude/encode', jsonParser, createWebTokenizerEncodingHandler(claude_tokenizer));
router.post('/llama3/encode', jsonParser, createWebTokenizerEncodingHandler(llama3_tokenizer));
router.post('/llama/decode', jsonParser, createSentencepieceDecodingHandler(spp_llama));
router.post('/nerdstash/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd));
router.post('/nerdstash_v2/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
router.post('/mistral/decode', jsonParser, createSentencepieceDecodingHandler(spp_mistral));
router.post('/yi/decode', jsonParser, createSentencepieceDecodingHandler(spp_yi));
router.post('/gpt2/decode', jsonParser, createTiktokenDecodingHandler('gpt2'));
router.post('/claude/decode', jsonParser, createWebTokenizerDecodingHandler(claude_tokenizer));
router.post('/llama3/decode', jsonParser, createWebTokenizerDecodingHandler(llama3_tokenizer));
router.post('/openai/encode', jsonParser, async function (req, res) {
try {
const queryModel = String(req.query.model || '');
if (queryModel.includes('llama3') || queryModel.includes('llama-3')) {
const handler = createWebTokenizerEncodingHandler(llama3_tokenizer);
return handler(req, res);
}
if (queryModel.includes('llama')) {
const handler = createSentencepieceEncodingHandler(spp_llama);
return handler(req, res);
@@ -528,12 +603,8 @@ router.post('/openai/encode', jsonParser, async function (req, res) {
}
if (queryModel.includes('claude')) {
const text = req.body.text || '';
const instance = await claude_tokenizer.get();
if (!instance) throw new Error('Failed to load the Claude tokenizer');
const tokens = Object.values(instance.encode(text));
const chunks = getWebTokenizersChunks(instance, tokens);
return res.send({ ids: tokens, count: tokens.length, chunks });
const handler = createWebTokenizerEncodingHandler(claude_tokenizer);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
@@ -549,6 +620,11 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
try {
const queryModel = String(req.query.model || '');
if (queryModel.includes('llama3') || queryModel.includes('llama-3')) {
const handler = createWebTokenizerDecodingHandler(llama3_tokenizer);
return handler(req, res);
}
if (queryModel.includes('llama')) {
const handler = createSentencepieceDecodingHandler(spp_llama);
return handler(req, res);
@@ -565,11 +641,8 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
}
if (queryModel.includes('claude')) {
const ids = req.body.ids || [];
const instance = await claude_tokenizer.get();
if (!instance) throw new Error('Failed to load the Claude tokenizer');
const chunkText = instance.decode(new Int32Array(ids));
return res.send({ text: chunkText });
const handler = createWebTokenizerDecodingHandler(claude_tokenizer);
return handler(req, res);
}
const model = getTokenizerModel(queryModel);
@@ -592,7 +665,14 @@ router.post('/openai/count', jsonParser, async function (req, res) {
if (model === 'claude') {
const instance = await claude_tokenizer.get();
if (!instance) throw new Error('Failed to load the Claude tokenizer');
num_tokens = countClaudeTokens(instance, req.body);
num_tokens = countWebTokenizerTokens(instance, req.body);
return res.send({ 'token_count': num_tokens });
}
if (model === 'llama3' || model === 'llama-3') {
const instance = await llama3_tokenizer.get();
if (!instance) throw new Error('Failed to load the Llama3 tokenizer');
num_tokens = countWebTokenizerTokens(instance, req.body);
return res.send({ 'token_count': num_tokens });
}
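
A rough usage sketch for the new Llama 3 branch of the count route above: the request body is the raw messages array that countWebTokenizerTokens receives via req.body (the role/content fields mirror the ChatML-style objects convertClaudePrompt consumes and are assumptions here), and the query model only needs to resolve to 'llama3' through getTokenizerModel. Headers beyond Content-Type are omitted.

const res = await fetch('/api/tokenizers/openai/count?model=llama3', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify([{ role: 'user', content: 'Hello!' }]),
});
const { token_count } = await res.json();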
@@ -755,7 +835,7 @@ module.exports = {
TEXT_COMPLETION_MODELS,
getTokenizerModel,
getTiktokenTokenizer,
countClaudeTokens,
countWebTokenizerTokens,
getSentencepiceTokenizer,
sentencepieceTokenizers,
router,

View File

@@ -2,6 +2,7 @@ require('./polyfill.js');
/**
* Convert a prompt from the ChatML objects to the format used by Claude.
* Mainly deprecated. Only used for counting tokens.
* @param {object[]} messages Array of messages
* @param {boolean} addAssistantPostfix Add Assistant postfix.
* @param {string} addAssistantPrefill Add Assistant prefill after the assistant postfix.

File diff suppressed because one or more lines are too long