From fedc3b887f8fb7b3dea02c115873a563a7542e3f Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Sun, 5 Nov 2023 21:54:19 +0200
Subject: [PATCH] Add llama2 tokenizer for OpenRouter models

---
 public/scripts/openai.js     |  2 +-
 public/scripts/tokenizers.js | 10 +++-
 src/tokenizers.js            | 97 +++++++++++++++++++++---------------
 3 files changed, 67 insertions(+), 42 deletions(-)

diff --git a/public/scripts/openai.js b/public/scripts/openai.js
index 9d6081e0a..c303237b1 100644
--- a/public/scripts/openai.js
+++ b/public/scripts/openai.js
@@ -153,7 +153,7 @@ const textCompletionModels = [
 ];
 
 let biasCache = undefined;
-let model_list = [];
+export let model_list = [];
 
 export const chat_completion_sources = {
     OPENAI: 'openai',
diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js
index eca027519..e481850ac 100644
--- a/public/scripts/tokenizers.js
+++ b/public/scripts/tokenizers.js
@@ -1,6 +1,6 @@
 import { characters, main_api, nai_settings, online_status, this_chid } from "../script.js";
 import { power_user, registerDebugFunction } from "./power-user.js";
-import { chat_completion_sources, oai_settings } from "./openai.js";
+import { chat_completion_sources, model_list, oai_settings } from "./openai.js";
 import { groups, selected_group } from "./group-chats.js";
 import { getStringHash } from "./utils.js";
 import { kai_flags } from "./kai-settings.js";
@@ -187,6 +187,7 @@ export function getTokenizerModel() {
     const gpt4Tokenizer = 'gpt-4';
     const gpt2Tokenizer = 'gpt2';
     const claudeTokenizer = 'claude';
+    const llamaTokenizer = 'llama';
 
     // Assuming no one would use it for different models.. right?
     if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) {
@@ -214,7 +215,12 @@ export function getTokenizerModel() {
 
     // And for OpenRouter (if not a site model, then it's impossible to determine the tokenizer)
     if (oai_settings.chat_completion_source == chat_completion_sources.OPENROUTER && oai_settings.openrouter_model) {
-        if (oai_settings.openrouter_model.includes('gpt-4')) {
+        const model = model_list.find(x => x.id === oai_settings.openrouter_model);
+
+        if (model?.architecture?.tokenizer === 'Llama2') {
+            return llamaTokenizer;
+        }
+        else if (oai_settings.openrouter_model.includes('gpt-4')) {
             return gpt4Tokenizer;
         }
         else if (oai_settings.openrouter_model.includes('gpt-3.5-turbo-0301')) {
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 7306c97b8..4e8b9d346 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -87,6 +87,10 @@ function getTokenizerModel(requestModel) {
         return 'claude';
     }
 
+    if (requestModel.includes('llama')) {
+        return 'llama';
+    }
+
     if (requestModel.includes('gpt-4-32k')) {
         return 'gpt-4-32k';
     }
@@ -288,49 +292,64 @@ function registerEndpoints(app, jsonParser) {
     app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2));
     app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
 
-    app.post("/api/tokenize/openai", jsonParser, function (req, res) {
-        if (!req.body) return res.sendStatus(400);
+    app.post("/api/tokenize/openai", jsonParser, async function (req, res) {
+        try {
+            if (!req.body) return res.sendStatus(400);
 
-        let num_tokens = 0;
-        const queryModel = String(req.query.model || '');
-        const model = getTokenizerModel(queryModel);
+            let num_tokens = 0;
+            const queryModel = String(req.query.model || '');
+            const model = getTokenizerModel(queryModel);
 
-        if (model == 'claude') {
-            num_tokens = countClaudeTokens(claude_tokenizer, req.body);
-            return res.send({ "token_count": num_tokens });
-        }
-
-        const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
-        const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
-        const tokensPadding = 3;
-
-        const tokenizer = getTiktokenTokenizer(model);
-
-        for (const msg of req.body) {
-            try {
-                num_tokens += tokensPerMessage;
-                for (const [key, value] of Object.entries(msg)) {
-                    num_tokens += tokenizer.encode(value).length;
-                    if (key == "name") {
-                        num_tokens += tokensPerName;
-                    }
-                }
-            } catch {
-                console.warn("Error tokenizing message:", msg);
+            if (model == 'claude') {
+                num_tokens = countClaudeTokens(claude_tokenizer, req.body);
+                return res.send({ "token_count": num_tokens });
             }
+
+            if (model == 'llama') {
+                const jsonBody = req.body.flatMap(x => Object.values(x)).join('\n\n');
+                const llamaResult = await countSentencepieceTokens(spp_llama, jsonBody);
+                console.log('jsonBody', jsonBody, 'llamaResult', llamaResult);
+                num_tokens = llamaResult.count;
+                return res.send({ "token_count": num_tokens });
+            }
+
+            const tokensPerName = queryModel.includes('gpt-3.5-turbo-0301') ? -1 : 1;
+            const tokensPerMessage = queryModel.includes('gpt-3.5-turbo-0301') ? 4 : 3;
+            const tokensPadding = 3;
+
+            const tokenizer = getTiktokenTokenizer(model);
+
+            for (const msg of req.body) {
+                try {
+                    num_tokens += tokensPerMessage;
+                    for (const [key, value] of Object.entries(msg)) {
+                        num_tokens += tokenizer.encode(value).length;
+                        if (key == "name") {
+                            num_tokens += tokensPerName;
+                        }
+                    }
+                } catch {
+                    console.warn("Error tokenizing message:", msg);
+                }
+            }
+            num_tokens += tokensPadding;
+
+            // NB: Since 2023-10-14, the GPT-3.5 Turbo 0301 model shoves in 7-9 extra tokens to every message.
+            // More details: https://community.openai.com/t/gpt-3-5-turbo-0301-showing-different-behavior-suddenly/431326/14
+            if (queryModel.includes('gpt-3.5-turbo-0301')) {
+                num_tokens += 9;
+            }
+
+            // not needed for cached tokenizers
+            //tokenizer.free();
+
+            res.send({ "token_count": num_tokens });
+        } catch (error) {
+            console.error('An error counting tokens, using fallback estimation method', error);
+            const jsonBody = JSON.stringify(req.body);
+            const num_tokens = Math.ceil(jsonBody.length / CHARS_PER_TOKEN);
+            res.send({ "token_count": num_tokens });
         }
-        num_tokens += tokensPadding;
-
-        // NB: Since 2023-10-14, the GPT-3.5 Turbo 0301 model shoves in 7-9 extra tokens to every message.
-        // More details: https://community.openai.com/t/gpt-3-5-turbo-0301-showing-different-behavior-suddenly/431326/14
-        if (queryModel.includes('gpt-3.5-turbo-0301')) {
-            num_tokens += 9;
-        }
-
-        // not needed for cached tokenizers
-        //tokenizer.free();
-
-        res.send({ "token_count": num_tokens });
     });
 }
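
Note (not part of the patch): a minimal client-side sketch of how the new Llama branch of the /api/tokenize/openai handler could be exercised. The endpoint path, the `model` query parameter, the message-array body, and the `token_count` response field follow the handler above; the OpenRouter model id is only an illustrative example, and any CSRF/authentication headers the real frontend attaches are omitted.

    // Hypothetical usage sketch. 'meta-llama/llama-2-13b-chat' is an example
    // OpenRouter model id; the server-side getTokenizerModel() picks the
    // 'llama' tokenizer because the id contains 'llama'.
    const messages = [
        { role: 'system', content: 'You are a helpful assistant.' },
        { role: 'user', content: 'Hello there!' },
    ];

    const response = await fetch('/api/tokenize/openai?model=' + encodeURIComponent('meta-llama/llama-2-13b-chat'), {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' }, // real callers also send CSRF/auth headers
        body: JSON.stringify(messages),
    });

    const { token_count } = await response.json();
    console.log('Llama-2 token count:', token_count);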