diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js index b72b672a8..c67e531a5 100644 --- a/public/scripts/tokenizers.js +++ b/public/scripts/tokenizers.js @@ -173,7 +173,7 @@ function callTokenizer(type, str, padding) { case tokenizers.YI: return countTokensFromServer('/api/tokenizers/yi/encode', str, padding); case tokenizers.API: - return countTokensFromServer('/api/tokenizers/remote/encode', str, padding); + return countTokensFromRemoteAPI('/api/tokenizers/remote/encode', str, padding); default: console.warn('Unknown tokenizer type', type); return callTokenizer(tokenizers.NONE, str, padding); @@ -392,6 +392,12 @@ function getTokenCacheObject() { } function getServerTokenizationParams(str) { + return { + text: str, + }; +} + +function getRemoteAPITokenizationParams(str) { return { text: str, main_api, @@ -404,7 +410,7 @@ function getServerTokenizationParams(str) { } /** - * Counts token using the server API. + * Count tokens using the server API. * @param {string} endpoint API endpoint. * @param {string} str String to tokenize. * @param {number} padding Number of padding tokens. @@ -424,18 +430,7 @@ function countTokensFromServer(endpoint, str, padding) { if (typeof data.count === 'number') { tokenCount = data.count; } else { - tokenCount = guesstimate(str); - console.error('Error counting tokens'); - - if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) { - toastr.warning( - 'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.', - 'Error counting tokens', - { timeOut: 10000, preventDuplicates: true }, - ); - - sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true)); - } + tokenCount = apiFailureTokenCount(str); } }, }); @@ -443,6 +438,51 @@ function countTokensFromServer(endpoint, str, padding) { return tokenCount + padding; } +/** + * Count tokens using the AI provider's API. + * @param {string} endpoint API endpoint. + * @param {string} str String to tokenize. + * @param {number} padding Number of padding tokens. + * @returns {number} Token count with padding. + */ +function countTokensFromRemoteAPI(endpoint, str, padding) { + let tokenCount = 0; + + jQuery.ajax({ + async: false, + type: 'POST', + url: endpoint, + data: JSON.stringify(getRemoteAPITokenizationParams(str)), + dataType: 'json', + contentType: 'application/json', + success: function (data) { + if (typeof data.count === 'number') { + tokenCount = data.count; + } else { + tokenCount = apiFailureTokenCount(str); + } + }, + }); + + return tokenCount + padding; +} + +function apiFailureTokenCount(str) { + console.error('Error counting tokens'); + + if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) { + toastr.warning( + 'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.', + 'Error counting tokens', + { timeOut: 10000, preventDuplicates: true }, + ); + + sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true)); + } + + return guesstimate(str); +} + /** * Calls the underlying tokenizer model to encode a string to tokens. * @param {string} endpoint API endpoint. @@ -475,6 +515,29 @@ function getTextTokensFromServer(endpoint, str, model = '') { return ids; } +/** + * Calls the AI provider's tokenize API to encode a string to tokens. + * @param {string} endpoint API endpoint. + * @param {string} str String to tokenize. + * @param {string} model Tokenizer model. + * @returns {number[]} Array of token ids. + */ +function getTextTokensFromRemoteAPI(endpoint, str, model = '') { + let ids = []; + jQuery.ajax({ + async: false, + type: 'POST', + url: endpoint, + data: JSON.stringify(getRemoteAPITokenizationParams(str)), + dataType: 'json', + contentType: 'application/json', + success: function (data) { + ids = data.ids; + }, + }); + return ids; +} + /** * Calls the underlying tokenizer model to decode token ids to text. * @param {string} endpoint API endpoint. @@ -525,7 +588,7 @@ export function getTextTokens(tokenizerType, str) { return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model); } case tokenizers.API: - return getTextTokensFromServer('/api/tokenizers/remote/encode', str); + return getTextTokensFromRemoteAPI('/api/tokenizers/remote/encode', str); default: console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType); return [];