Add llama 3 tokenizer
parent 7bc87b6e28
commit 7bfd666321
@@ -3144,11 +3144,13 @@
    <option value="0">None / Estimated</option>
    <option value="1">GPT-2</option>
    <!-- Option #2 was a legacy GPT-2/3 tokenizer -->
    <option value="3">LLaMA</option>
    <option value="3">Llama 1/2</option>
    <option value="12">Llama 3</option>
    <option value="4">NerdStash (NovelAI Clio)</option>
    <option value="5">NerdStash v2 (NovelAI Kayra)</option>
    <option value="7">Mistral</option>
    <option value="8">Yi</option>
    <option value="11">Claude 1/2</option>
    <option value="6">API (WebUI / koboldcpp)</option>
</select>
</div>
@@ -440,6 +440,10 @@ export function getCurrentOpenRouterModelTokenizer() {
    switch (model?.architecture?.tokenizer) {
        case 'Llama2':
            return tokenizers.LLAMA;
        case 'Llama3':
            return tokenizers.LLAMA3;
        case 'Yi':
            return tokenizers.YI;
        case 'Mistral':
            return tokenizers.MISTRAL;
        default:
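For illustration only (not part of the commit): an OpenRouter model entry that reports the Llama 3 tokenizer in its metadata now resolves to the new enum value. The object shape below is a minimal stand-in and the enum is abridged.

// Abridged stand-in for the tokenizers enum; LLAMA3 is 12 per the change further down in this diff.
const tokenizers = { LLAMA: 3, LLAMA3: 12 };
// Minimal sketch of an OpenRouter model entry exposing architecture.tokenizer.
const model = { architecture: { tokenizer: 'Llama3' } };

const picked = model?.architecture?.tokenizer === 'Llama3' ? tokenizers.LLAMA3 : tokenizers.LLAMA;
console.log(picked); // 12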
@@ -932,7 +932,7 @@ function toIntArray(string) {
    return string.split(',').map(x => parseInt(x)).filter(x => !isNaN(x));
}

function getModel() {
export function getTextGenModel() {
    switch (settings.type) {
        case OOBA:
            if (settings.custom_model) {
@@ -974,7 +974,7 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
    const canMultiSwipe = !isContinue && !isImpersonate && type !== 'quiet';
    let params = {
        'prompt': finalPrompt,
        'model': getModel(),
        'model': getTextGenModel(),
        'max_new_tokens': maxTokens,
        'max_tokens': maxTokens,
        'logprobs': power_user.request_token_probabilities ? 10 : undefined,
@@ -4,7 +4,7 @@ import { chat_completion_sources, model_list, oai_settings } from './openai.js';
import { groups, selected_group } from './group-chats.js';
import { getStringHash } from './utils.js';
import { kai_flags } from './kai-settings.js';
import { textgen_types, textgenerationwebui_settings as textgen_settings, getTextGenServer } from './textgen-settings.js';
import { textgen_types, textgenerationwebui_settings as textgen_settings, getTextGenServer, getTextGenModel } from './textgen-settings.js';
import { getCurrentDreamGenModelTokenizer, getCurrentOpenRouterModelTokenizer, openRouterModels } from './textgen-models.js';

const { OOBA, TABBY, KOBOLDCPP, APHRODITE, LLAMACPP, OPENROUTER, DREAMGEN } = textgen_types;
@@ -24,6 +24,8 @@ export const tokenizers = {
    YI: 8,
    API_TEXTGENERATIONWEBUI: 9,
    API_KOBOLD: 10,
    CLAUDE: 11,
    LLAMA3: 12,
    BEST_MATCH: 99,
};
@@ -31,6 +33,7 @@ export const SENTENCEPIECE_TOKENIZERS = [
    tokenizers.LLAMA,
    tokenizers.MISTRAL,
    tokenizers.YI,
    tokenizers.LLAMA3,
    // uncomment when NovelAI releases Kayra and Clio weights, lol
    //tokenizers.NERD,
    //tokenizers.NERD2,
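Editor's illustration, not part of the diff: the numeric ids in the tokenizers enum above line up with the <option value> attributes in the dropdown from the first hunk (Yi is 8, Claude 1/2 is 11, the new Llama 3 entry is 12), so a UI selection can be compared against the enum directly.

// Abridged copy of the enum for illustration; the real values are in the hunk above.
const tokenizers = { YI: 8, CLAUDE: 11, LLAMA3: 12, BEST_MATCH: 99 };

// What the tokenizer <select> reports after the user picks "Llama 3" (value="12").
const selectedValue = '12';
console.log(Number(selectedValue) === tokenizers.LLAMA3); // true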
@@ -78,6 +81,16 @@ const TOKENIZER_URLS = {
        decode: '/api/tokenizers/yi/decode',
        count: '/api/tokenizers/yi/encode',
    },
    [tokenizers.CLAUDE]: {
        encode: '/api/tokenizers/claude/encode',
        decode: '/api/tokenizers/claude/decode',
        count: '/api/tokenizers/claude/encode',
    },
    [tokenizers.LLAMA3]: {
        encode: '/api/tokenizers/llama3/encode',
        decode: '/api/tokenizers/llama3/decode',
        count: '/api/tokenizers/llama3/encode',
    },
    [tokenizers.API_TEXTGENERATIONWEBUI]: {
        encode: '/api/tokenizers/remote/textgenerationwebui/encode',
        count: '/api/tokenizers/remote/textgenerationwebui/encode',
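A rough usage sketch of the new llama3 endpoints (assumptions: the server listens on localhost:8000, and any CSRF/auth headers it may require are omitted). The { ids, count, chunks } and { text, chunks } response shapes match what the WebTokenizer handlers later in this diff send back.

// Encode a string, then decode it back, via the new Llama 3 tokenizer endpoints.
async function llama3RoundTrip(text) {
    const base = 'http://localhost:8000/api/tokenizers/llama3';
    const headers = { 'Content-Type': 'application/json' };

    const encodeRes = await fetch(`${base}/encode`, { method: 'POST', headers, body: JSON.stringify({ text }) });
    const { ids, count } = await encodeRes.json();

    const decodeRes = await fetch(`${base}/decode`, { method: 'POST', headers, body: JSON.stringify({ ids }) });
    const { text: roundTripped } = await decodeRes.json();

    console.log(count, 'tokens; round-trip:', roundTripped);
}

llama3RoundTrip('Hello, Llama 3!').catch(console.error);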
@@ -211,6 +224,16 @@ export function getTokenizerBestMatch(forApi) {
            }
        }

        if (forApi === 'textgenerationwebui') {
            const model = String(getTextGenModel() || online_status).toLowerCase();
            if (model.includes('llama3') || model.includes('llama-3')) {
                return tokenizers.LLAMA3;
            }
            if (model.includes('mistral') || model.includes('mixtral')) {
                return tokenizers.MISTRAL;
            }
        }

        return tokenizers.LLAMA;
    }
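A quick worked example of the new name matching, using made-up model strings; the plain Llama fallback only applies when neither Llama 3 spelling matches.

// Made-up names; whatever the backend reports is lowercased first, as in the hunk above.
const gguf = 'Meta-Llama-3-8B-Instruct.Q5_K_M.gguf'.toLowerCase();
console.log(gguf.includes('llama3') || gguf.includes('llama-3')); // true  -> tokenizers.LLAMA3
console.log('mixtral-8x7b-instruct-v0.1'.includes('mixtral'));    // true  -> tokenizers.MISTRAL
console.log('some-other-model'.includes('llama-3'));              // false -> falls back to tokenizers.LLAMA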
@@ -421,6 +444,7 @@ export function getTokenizerModel() {
    const gpt2Tokenizer = 'gpt2';
    const claudeTokenizer = 'claude';
    const llamaTokenizer = 'llama';
    const llama3Tokenizer = 'llama3';
    const mistralTokenizer = 'mistral';
    const yiTokenizer = 'yi';
@@ -458,6 +482,9 @@ export function getTokenizerModel() {
        if (model?.architecture?.tokenizer === 'Llama2') {
            return llamaTokenizer;
        }
        else if (model?.architecture?.tokenizer === 'Llama3') {
            return llama3Tokenizer;
        }
        else if (model?.architecture?.tokenizer === 'Mistral') {
            return mistralTokenizer;
        }
@@ -498,10 +525,13 @@ export function getTokenizerModel() {
    }

    if (oai_settings.chat_completion_source === chat_completion_sources.PERPLEXITY) {
        if (oai_settings.perplexity_model.includes('llama-3') || oai_settings.perplexity_model.includes('llama3')) {
            return llama3Tokenizer;
        }
        if (oai_settings.perplexity_model.includes('llama')) {
            return llamaTokenizer;
        }
        if (oai_settings.perplexity_model.includes('mistral')) {
        if (oai_settings.perplexity_model.includes('mistral') || oai_settings.perplexity_model.includes('mixtral')) {
            return mistralTokenizer;
        }
    }
@@ -730,7 +730,11 @@ router.post('/bias', jsonParser, async function (request, response) {
        if (sentencepieceTokenizers.includes(model)) {
            const tokenizer = getSentencepiceTokenizer(model);
            const instance = await tokenizer?.get();
            encodeFunction = (text) => new Uint32Array(instance?.encodeIds(text));
            if (!instance) {
                console.warn('Tokenizer not initialized:', model);
                return response.send({});
            }
            encodeFunction = (text) => new Uint32Array(instance.encodeIds(text));
        } else {
            const tokenizer = getTiktokenTokenizer(model);
            encodeFunction = (tokenizer.encode.bind(tokenizer));
@@ -142,6 +142,7 @@ const spp_nerd_v2 = new SentencePieceTokenizer('src/tokenizers/nerdstash_v2.mode
const spp_mistral = new SentencePieceTokenizer('src/tokenizers/mistral.model');
const spp_yi = new SentencePieceTokenizer('src/tokenizers/yi.model');
const claude_tokenizer = new WebTokenizer('src/tokenizers/claude.json');
const llama3_tokenizer = new WebTokenizer('src/tokenizers/llama3.json');

const sentencepieceTokenizers = [
    'llama',
@@ -285,6 +286,10 @@ function getTokenizerModel(requestModel) {
        return 'claude';
    }

    if (requestModel.includes('llama3') || requestModel.includes('llama-3')) {
        return 'llama3';
    }

    if (requestModel.includes('llama')) {
        return 'llama';
    }
@@ -313,12 +318,12 @@ function getTiktokenTokenizer(model) {
}

/**
 * Counts the tokens for the given messages using the Claude tokenizer.
 * Counts the tokens for the given messages using the WebTokenizer and Claude prompt conversion.
 * @param {Tokenizer} tokenizer Web tokenizer
 * @param {object[]} messages Array of messages
 * @returns {number} Number of tokens
 */
function countClaudeTokens(tokenizer, messages) {
function countWebTokenizerTokens(tokenizer, messages) {
    // Should be fine if we use the old conversion method instead of the messages API one i think?
    const convertedPrompt = convertClaudePrompt(messages, false, '', false, false, '', false);
@@ -449,6 +454,67 @@ function createTiktokenDecodingHandler(modelId) {
    };
}

/**
 * Creates an API handler for encoding WebTokenizer tokens.
 * @param {WebTokenizer} tokenizer WebTokenizer instance
 * @returns {TokenizationHandler} Handler function
 */
function createWebTokenizerEncodingHandler(tokenizer) {
    /**
     * Request handler for encoding WebTokenizer tokens.
     * @param {import('express').Request} request
     * @param {import('express').Response} response
     */
    return async function (request, response) {
        try {
            if (!request.body) {
                return response.sendStatus(400);
            }

            const text = request.body.text || '';
            const instance = await tokenizer?.get();
            if (!instance) throw new Error('Failed to load the Web tokenizer');
            const tokens = Array.from(instance.encode(text));
            const chunks = getWebTokenizersChunks(instance, tokens);
            return response.send({ ids: tokens, count: tokens.length, chunks });
        } catch (error) {
            console.log(error);
            return response.send({ ids: [], count: 0, chunks: [] });
        }
    };
}

/**
 * Creates an API handler for decoding WebTokenizer tokens.
 * @param {WebTokenizer} tokenizer WebTokenizer instance
 * @returns {TokenizationHandler} Handler function
 */
function createWebTokenizerDecodingHandler(tokenizer) {
    /**
     * Request handler for decoding WebTokenizer tokens.
     * @param {import('express').Request} request
     * @param {import('express').Response} response
     * @returns {Promise<any>}
     */
    return async function (request, response) {
        try {
            if (!request.body) {
                return response.sendStatus(400);
            }

            const ids = request.body.ids || [];
            const instance = await tokenizer?.get();
            if (!instance) throw new Error('Failed to load the Web tokenizer');
            const chunks = getWebTokenizersChunks(instance, ids);
            const text = instance.decode(new Int32Array(ids));
            return response.send({ text, chunks });
        } catch (error) {
            console.log(error);
            return response.send({ text: '', chunks: [] });
        }
    };
}

const router = express.Router();

router.post('/ai21/count', jsonParser, async function (req, res) {
@@ -501,17 +567,26 @@ router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandl
router.post('/mistral/encode', jsonParser, createSentencepieceEncodingHandler(spp_mistral));
router.post('/yi/encode', jsonParser, createSentencepieceEncodingHandler(spp_yi));
router.post('/gpt2/encode', jsonParser, createTiktokenEncodingHandler('gpt2'));
router.post('/claude/encode', jsonParser, createWebTokenizerEncodingHandler(claude_tokenizer));
router.post('/llama3/encode', jsonParser, createWebTokenizerEncodingHandler(llama3_tokenizer));
router.post('/llama/decode', jsonParser, createSentencepieceDecodingHandler(spp_llama));
router.post('/nerdstash/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd));
router.post('/nerdstash_v2/decode', jsonParser, createSentencepieceDecodingHandler(spp_nerd_v2));
router.post('/mistral/decode', jsonParser, createSentencepieceDecodingHandler(spp_mistral));
router.post('/yi/decode', jsonParser, createSentencepieceDecodingHandler(spp_yi));
router.post('/gpt2/decode', jsonParser, createTiktokenDecodingHandler('gpt2'));
router.post('/claude/decode', jsonParser, createWebTokenizerDecodingHandler(claude_tokenizer));
router.post('/llama3/decode', jsonParser, createWebTokenizerDecodingHandler(llama3_tokenizer));

router.post('/openai/encode', jsonParser, async function (req, res) {
    try {
        const queryModel = String(req.query.model || '');

        if (queryModel.includes('llama3') || queryModel.includes('llama-3')) {
            const handler = createWebTokenizerEncodingHandler(llama3_tokenizer);
            return handler(req, res);
        }

        if (queryModel.includes('llama')) {
            const handler = createSentencepieceEncodingHandler(spp_llama);
            return handler(req, res);
@@ -528,12 +603,8 @@ router.post('/openai/encode', jsonParser, async function (req, res) {
        }

        if (queryModel.includes('claude')) {
            const text = req.body.text || '';
            const instance = await claude_tokenizer.get();
            if (!instance) throw new Error('Failed to load the Claude tokenizer');
            const tokens = Object.values(instance.encode(text));
            const chunks = getWebTokenizersChunks(instance, tokens);
            return res.send({ ids: tokens, count: tokens.length, chunks });
            const handler = createWebTokenizerEncodingHandler(claude_tokenizer);
            return handler(req, res);
        }

        const model = getTokenizerModel(queryModel);
@@ -549,6 +620,11 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
    try {
        const queryModel = String(req.query.model || '');

        if (queryModel.includes('llama3') || queryModel.includes('llama-3')) {
            const handler = createWebTokenizerDecodingHandler(llama3_tokenizer);
            return handler(req, res);
        }

        if (queryModel.includes('llama')) {
            const handler = createSentencepieceDecodingHandler(spp_llama);
            return handler(req, res);
@@ -565,11 +641,8 @@ router.post('/openai/decode', jsonParser, async function (req, res) {
        }

        if (queryModel.includes('claude')) {
            const ids = req.body.ids || [];
            const instance = await claude_tokenizer.get();
            if (!instance) throw new Error('Failed to load the Claude tokenizer');
            const chunkText = instance.decode(new Int32Array(ids));
            return res.send({ text: chunkText });
            const handler = createWebTokenizerDecodingHandler(claude_tokenizer);
            return handler(req, res);
        }

        const model = getTokenizerModel(queryModel);
@@ -592,7 +665,14 @@ router.post('/openai/count', jsonParser, async function (req, res) {
        if (model === 'claude') {
            const instance = await claude_tokenizer.get();
            if (!instance) throw new Error('Failed to load the Claude tokenizer');
            num_tokens = countClaudeTokens(instance, req.body);
            num_tokens = countWebTokenizerTokens(instance, req.body);
            return res.send({ 'token_count': num_tokens });
        }

        if (model === 'llama3' || model === 'llama-3') {
            const instance = await llama3_tokenizer.get();
            if (!instance) throw new Error('Failed to load the Llama3 tokenizer');
            num_tokens = countWebTokenizerTokens(instance, req.body);
            return res.send({ 'token_count': num_tokens });
        }
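A hedged client-side sketch of the counting route with a Llama 3 model name. Assumptions: the router is mounted at /api/tokenizers (consistent with the URLs earlier in this diff), host and CSRF/auth headers are omitted, the request body is the messages array (matching how req.body is passed to countWebTokenizerTokens), and '?model=llama3' is assumed to reach the llama3 branch above, by analogy with the encode/decode routes.

// Hypothetical call; the messages array mirrors the chat-completion format the route consumes.
async function countChatTokens(messages) {
    const response = await fetch('http://localhost:8000/api/tokenizers/openai/count?model=llama3', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify(messages),
    });
    const { token_count } = await response.json();
    return token_count;
}

countChatTokens([{ role: 'user', content: 'Hello there!' }]).then(console.log).catch(console.error);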
@@ -755,7 +835,7 @@ module.exports = {
    TEXT_COMPLETION_MODELS,
    getTokenizerModel,
    getTiktokenTokenizer,
    countClaudeTokens,
    countWebTokenizerTokens,
    getSentencepiceTokenizer,
    sentencepieceTokenizers,
    router,
@@ -2,6 +2,7 @@ require('./polyfill.js');

/**
 * Convert a prompt from the ChatML objects to the format used by Claude.
 * Mainly deprecated. Only used for counting tokens.
 * @param {object[]} messages Array of messages
 * @param {boolean} addAssistantPostfix Add Assistant postfix.
 * @param {string} addAssistantPrefill Add Assistant prefill after the assistant postfix.
File diff suppressed because one or more lines are too long