From ab52af4fb5cc4ee27a0146f42d612c7ba9e42f33 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Thu, 24 Aug 2023 20:19:57 +0300 Subject: [PATCH] Add support for Koboldcpp tokenization endpoint --- public/index.html | 2 +- public/scripts/tokenizers.js | 37 ++++++++++++++++++++++++++++++++++-- server.js | 14 +++++++++++--- 3 files changed, 47 insertions(+), 6 deletions(-) diff --git a/public/index.html b/public/index.html index b4726d0d8..6afcb8388 100644 --- a/public/index.html +++ b/public/index.html @@ -2245,7 +2245,7 @@ - +
diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js index 8f9b82405..4900e30c8 100644 --- a/public/scripts/tokenizers.js +++ b/public/scripts/tokenizers.js @@ -24,6 +24,15 @@ const gpt3 = new GPT3BrowserTokenizer({ type: 'gpt3' }); let tokenCache = {}; +/** + * Guesstimates the token count for a string. + * @param {string} str String to tokenize. + * @returns {number} Token count. + */ +export function guesstimate(str) { + return Math.ceil(str.length / CHARACTERS_PER_TOKEN_RATIO); +} + async function loadTokenCache() { try { console.debug('Chat Completions: loading token cache') @@ -89,7 +98,7 @@ export function getTokenCount(str, padding = undefined) { function calculate(type) { switch (type) { case tokenizers.NONE: - return Math.ceil(str.length / CHARACTERS_PER_TOKEN_RATIO) + padding; + return guesstimate(str) + padding; case tokenizers.GPT3: return gpt3.encode(str).bpe.length + padding; case tokenizers.CLASSIC: @@ -291,8 +300,16 @@ function getTokenCacheObject() { return tokenCache[String(chatId)]; } +/** + * Counts token using the remote server API. + * @param {string} endpoint API endpoint. + * @param {string} str String to tokenize. + * @param {number} padding Number of padding tokens. + * @returns {number} Token count with padding. + */ function countTokensRemote(endpoint, str, padding) { let tokenCount = 0; + jQuery.ajax({ async: false, type: 'POST', @@ -301,9 +318,25 @@ function countTokensRemote(endpoint, str, padding) { dataType: "json", contentType: "application/json", success: function (data) { - tokenCount = data.count; + if (typeof data.count === 'number') { + tokenCount = data.count; + } else { + tokenCount = guesstimate(str); + console.error("Error counting tokens"); + + if (!sessionStorage.getItem('tokenizationWarningShown')) { + toastr.warning( + "Your selected API doesn't support the tokenization endpoint. Using estimated counts.", + "Error counting tokens", + { timeOut: 10000, preventDuplicates: true }, + ); + + sessionStorage.setItem('tokenizationWarningShown', String(true)); + } + } } }); + return tokenCount + padding; } diff --git a/server.js b/server.js index b46880ccb..50e4dcfb3 100644 --- a/server.js +++ b/server.js @@ -3842,11 +3842,19 @@ app.post("/tokenize_via_api", jsonParser, async function (request, response) { if (main_api == 'textgenerationwebui' && request.body.use_mancer) { args.headers = Object.assign(args.headers, get_mancer_headers()); + const data = await postAsync(api_server + "/v1/token-count", args); + return response.send({ count: data['results'][0]['tokens'] }); } - const data = await postAsync(api_server + "/v1/token-count", args); - console.log(data); - return response.send({ count: data['results'][0]['tokens'] }); + else if (main_api == 'kobold') { + const data = await postAsync(api_server + "/extra/tokencount", args); + const count = data['value']; + return response.send({ count: count }); + } + + else { + return response.send({ error: true }); + } } catch (error) { console.log(error); return response.send({ error: true });