From 1c4bad35b2bc632e65bea81d86ec7364493471db Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Sat, 13 Apr 2024 21:05:31 +0300 Subject: [PATCH] #2085 Implement async token counting --- public/script.js | 4 +- public/scripts/RossAscends-mods.js | 30 +++-- public/scripts/openai.js | 3 +- public/scripts/power-user.js | 11 +- public/scripts/tokenizers.js | 207 +++++++++++++++++++++++++++-- 5 files changed, 229 insertions(+), 26 deletions(-) diff --git a/public/script.js b/public/script.js index a20a04ccb..5f591914a 100644 --- a/public/script.js +++ b/public/script.js @@ -82,6 +82,7 @@ import { flushEphemeralStoppingStrings, context_presets, resetMovableStyles, + forceCharacterEditorTokenize, } from './scripts/power-user.js'; import { @@ -5840,10 +5841,11 @@ function changeMainAPI() { if (main_api == 'koboldhorde') { getStatusHorde(); - getHordeModels(); + getHordeModels(true); } setupChatCompletionPromptManager(oai_settings); + forceCharacterEditorTokenize(); } //////////////////////////////////////////////////// diff --git a/public/scripts/RossAscends-mods.js b/public/scripts/RossAscends-mods.js index a6bebf3e4..57c135fb5 100644 --- a/public/scripts/RossAscends-mods.js +++ b/public/scripts/RossAscends-mods.js @@ -34,7 +34,7 @@ import { } from './secrets.js'; import { debounce, delay, getStringHash, isValidUrl } from './utils.js'; import { chat_completion_sources, oai_settings } from './openai.js'; -import { getTokenCount } from './tokenizers.js'; +import { getTokenCountAsync } from './tokenizers.js'; import { textgen_types, textgenerationwebui_settings as textgen_settings, getTextGenServer } from './textgen-settings.js'; import Bowser from '../lib/bowser.min.js'; @@ -51,6 +51,7 @@ var SelectedCharacterTab = document.getElementById('rm_button_selected_ch'); var connection_made = false; var retry_delay = 500; +let counterNonce = Date.now(); const observerConfig = { childList: true, subtree: true }; const countTokensDebounced = debounce(RA_CountCharTokens, 1000); @@ -202,24 +203,32 @@ $('#rm_ch_create_block').on('input', function () { countTokensDebounced(); }); //when any input is made to the advanced editing popup textareas $('#character_popup').on('input', function () { countTokensDebounced(); }); //function: -export function RA_CountCharTokens() { +export async function RA_CountCharTokens() { + counterNonce = Date.now(); + const counterNonceLocal = counterNonce; let total_tokens = 0; let permanent_tokens = 0; - $('[data-token-counter]').each(function () { - const counter = $(this); + const tokenCounters = document.querySelectorAll('[data-token-counter]'); + for (const tokenCounter of tokenCounters) { + if (counterNonceLocal !== counterNonce) { + return; + } + + const counter = $(tokenCounter); const input = $(document.getElementById(counter.data('token-counter'))); const isPermanent = counter.data('token-permanent') === true; const value = String(input.val()); if (input.length === 0) { counter.text('Invalid input reference'); - return; + continue; } if (!value) { + input.data('last-value-hash', ''); counter.text(0); - return; + continue; } const valueHash = getStringHash(value); @@ -230,13 +239,18 @@ export function RA_CountCharTokens() { } else { // We substitute macro for existing characters, but not for the character being created const valueToCount = menu_type === 'create' ? value : substituteParams(value); - const tokens = getTokenCount(valueToCount); + const tokens = await getTokenCountAsync(valueToCount); + + if (counterNonceLocal !== counterNonce) { + return; + } + counter.text(tokens); total_tokens += tokens; permanent_tokens += isPermanent ? tokens : 0; input.data('last-value-hash', valueHash); } - }); + } // Warn if total tokens exceeds the limit of half the max context const tokenLimit = Math.max(((main_api !== 'openai' ? max_context : oai_settings.openai_max_context) / 2), 1024); diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 0e3a66c73..b8d7d192e 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -42,7 +42,7 @@ import { promptManagerDefaultPromptOrders, } from './PromptManager.js'; -import { getCustomStoppingStrings, persona_description_positions, power_user } from './power-user.js'; +import { forceCharacterEditorTokenize, getCustomStoppingStrings, persona_description_positions, power_user } from './power-user.js'; import { SECRET_KEYS, secret_state, writeSecret } from './secrets.js'; import { getEventSourceStream } from './sse-stream.js'; @@ -4429,6 +4429,7 @@ $(document).ready(async function () { toggleChatCompletionForms(); saveSettingsDebounced(); reconnectOpenAi(); + forceCharacterEditorTokenize(); eventSource.emit(event_types.CHATCOMPLETION_SOURCE_CHANGED, oai_settings.chat_completion_source); }); diff --git a/public/scripts/power-user.js b/public/scripts/power-user.js index 2bf3169b5..35f4984b7 100644 --- a/public/scripts/power-user.js +++ b/public/scripts/power-user.js @@ -2764,6 +2764,14 @@ export function getCustomStoppingStrings(limit = undefined) { return strings; } +export function forceCharacterEditorTokenize() { + $('[data-token-counter]').each(function () { + $(document.getElementById($(this).data('token-counter'))).data('last-value-hash', ''); + }); + $('#rm_ch_create_block').trigger('input'); + $('#character_popup').trigger('input'); +} + $(document).ready(() => { const adjustAutocompleteDebounced = debounce(() => { $('.ui-autocomplete-input').each(function () { @@ -3175,8 +3183,7 @@ $(document).ready(() => { saveSettingsDebounced(); // Trigger character editor re-tokenize - $('#rm_ch_create_block').trigger('input'); - $('#character_popup').trigger('input'); + forceCharacterEditorTokenize(); }); $('#send_on_enter').on('change', function () { diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js index 03ae0b7f4..7e9fc7856 100644 --- a/public/scripts/tokenizers.js +++ b/public/scripts/tokenizers.js @@ -256,11 +256,93 @@ function callTokenizer(type, str) { } } +/** + * Calls the underlying tokenizer model to the token count for a string. + * @param {number} type Tokenizer type. + * @param {string} str String to tokenize. + * @returns {Promise} Token count. + */ +function callTokenizerAsync(type, str) { + return new Promise(resolve => { + if (type === tokenizers.NONE) { + return resolve(guesstimate(str)); + } + + switch (type) { + case tokenizers.API_CURRENT: + return callTokenizerAsync(currentRemoteTokenizerAPI(), str).then(resolve); + case tokenizers.API_KOBOLD: + return countTokensFromKoboldAPI(str, resolve); + case tokenizers.API_TEXTGENERATIONWEBUI: + return countTokensFromTextgenAPI(str, resolve); + default: { + const endpointUrl = TOKENIZER_URLS[type]?.count; + if (!endpointUrl) { + console.warn('Unknown tokenizer type', type); + return resolve(apiFailureTokenCount(str)); + } + return countTokensFromServer(endpointUrl, str, resolve); + } + } + }); +} + +/** + * Gets the token count for a string using the current model tokenizer. + * @param {string} str String to tokenize + * @param {number | undefined} padding Optional padding tokens. Defaults to 0. + * @returns {Promise} Token count. + */ +export async function getTokenCountAsync(str, padding = undefined) { + if (typeof str !== 'string' || !str?.length) { + return 0; + } + + let tokenizerType = power_user.tokenizer; + + if (main_api === 'openai') { + if (padding === power_user.token_padding) { + // For main "shadow" prompt building + tokenizerType = tokenizers.NONE; + } else { + // For extensions and WI + return counterWrapperOpenAIAsync(str); + } + } + + if (tokenizerType === tokenizers.BEST_MATCH) { + tokenizerType = getTokenizerBestMatch(main_api); + } + + if (padding === undefined) { + padding = 0; + } + + const cacheObject = getTokenCacheObject(); + const hash = getStringHash(str); + const cacheKey = `${tokenizerType}-${hash}+${padding}`; + + if (typeof cacheObject[cacheKey] === 'number') { + return cacheObject[cacheKey]; + } + + const result = (await callTokenizerAsync(tokenizerType, str)) + padding; + + if (isNaN(result)) { + console.warn('Token count calculation returned NaN'); + return 0; + } + + cacheObject[cacheKey] = result; + return result; +} + /** * Gets the token count for a string using the current model tokenizer. * @param {string} str String to tokenize * @param {number | undefined} padding Optional padding tokens. Defaults to 0. * @returns {number} Token count. + * @deprecated Use getTokenCountAsync instead. */ export function getTokenCount(str, padding = undefined) { if (typeof str !== 'string' || !str?.length) { @@ -310,12 +392,23 @@ export function getTokenCount(str, padding = undefined) { * Gets the token count for a string using the OpenAI tokenizer. * @param {string} text Text to tokenize. * @returns {number} Token count. + * @deprecated Use counterWrapperOpenAIAsync instead. */ function counterWrapperOpenAI(text) { const message = { role: 'system', content: text }; return countTokensOpenAI(message, true); } +/** + * Gets the token count for a string using the OpenAI tokenizer. + * @param {string} text Text to tokenize. + * @returns {Promise} Token count. + */ +function counterWrapperOpenAIAsync(text) { + const message = { role: 'system', content: text }; + return countTokensOpenAIAsync(message, true); +} + export function getTokenizerModel() { // OpenAI models always provide their own tokenizer if (oai_settings.chat_completion_source == chat_completion_sources.OPENAI) { @@ -410,6 +503,7 @@ export function getTokenizerModel() { /** * @param {any[] | Object} messages + * @deprecated Use countTokensOpenAIAsync instead. */ export function countTokensOpenAI(messages, full = false) { const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer; @@ -466,6 +560,66 @@ export function countTokensOpenAI(messages, full = false) { return token_count; } +/** + * Returns the token count for a message using the OpenAI tokenizer. + * @param {object[]|object} messages + * @param {boolean} full + * @returns {Promise} Token count. + */ +export async function countTokensOpenAIAsync(messages, full = false) { + const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer; + const shouldTokenizeGoogle = oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE && oai_settings.use_google_tokenizer; + let tokenizerEndpoint = ''; + if (shouldTokenizeAI21) { + tokenizerEndpoint = '/api/tokenizers/ai21/count'; + } else if (shouldTokenizeGoogle) { + tokenizerEndpoint = `/api/tokenizers/google/count?model=${getTokenizerModel()}`; + } else { + tokenizerEndpoint = `/api/tokenizers/openai/count?model=${getTokenizerModel()}`; + } + const cacheObject = getTokenCacheObject(); + + if (!Array.isArray(messages)) { + messages = [messages]; + } + + let token_count = -1; + + for (const message of messages) { + const model = getTokenizerModel(); + + if (model === 'claude' || shouldTokenizeAI21 || shouldTokenizeGoogle) { + full = true; + } + + const hash = getStringHash(JSON.stringify(message)); + const cacheKey = `${model}-${hash}`; + const cachedCount = cacheObject[cacheKey]; + + if (typeof cachedCount === 'number') { + token_count += cachedCount; + } + + else { + const data = await jQuery.ajax({ + async: true, + type: 'POST', // + url: tokenizerEndpoint, + data: JSON.stringify([message]), + dataType: 'json', + contentType: 'application/json', + }); + + token_count += Number(data.token_count); + cacheObject[cacheKey] = Number(data.token_count); + } + } + + if (!full) token_count -= 2; + + return token_count; +} + /** * Gets the token cache object for the current chat. * @returns {Object} Token cache object for the current chat. @@ -495,13 +649,15 @@ function getTokenCacheObject() { * Count tokens using the server API. * @param {string} endpoint API endpoint. * @param {string} str String to tokenize. + * @param {function} [resolve] Promise resolve function.s * @returns {number} Token count. */ -function countTokensFromServer(endpoint, str) { +function countTokensFromServer(endpoint, str, resolve) { + const isAsync = typeof resolve === 'function'; let tokenCount = 0; jQuery.ajax({ - async: false, + async: isAsync, type: 'POST', url: endpoint, data: JSON.stringify({ text: str }), @@ -513,6 +669,8 @@ function countTokensFromServer(endpoint, str) { } else { tokenCount = apiFailureTokenCount(str); } + + isAsync && resolve(tokenCount); }, }); @@ -522,13 +680,15 @@ function countTokensFromServer(endpoint, str) { /** * Count tokens using the AI provider's API. * @param {string} str String to tokenize. + * @param {function} [resolve] Promise resolve function. * @returns {number} Token count. */ -function countTokensFromKoboldAPI(str) { +function countTokensFromKoboldAPI(str, resolve) { + const isAsync = typeof resolve === 'function'; let tokenCount = 0; jQuery.ajax({ - async: false, + async: isAsync, type: 'POST', url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count, data: JSON.stringify({ @@ -543,6 +703,8 @@ function countTokensFromKoboldAPI(str) { } else { tokenCount = apiFailureTokenCount(str); } + + isAsync && resolve(tokenCount); }, }); @@ -561,13 +723,15 @@ function getTextgenAPITokenizationParams(str) { /** * Count tokens using the AI provider's API. * @param {string} str String to tokenize. + * @param {function} [resolve] Promise resolve function. * @returns {number} Token count. */ -function countTokensFromTextgenAPI(str) { +function countTokensFromTextgenAPI(str, resolve) { + const isAsync = typeof resolve === 'function'; let tokenCount = 0; jQuery.ajax({ - async: false, + async: isAsync, type: 'POST', url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count, data: JSON.stringify(getTextgenAPITokenizationParams(str)), @@ -579,6 +743,8 @@ function countTokensFromTextgenAPI(str) { } else { tokenCount = apiFailureTokenCount(str); } + + isAsync && resolve(tokenCount); }, }); @@ -605,12 +771,14 @@ function apiFailureTokenCount(str) { * Calls the underlying tokenizer model to encode a string to tokens. * @param {string} endpoint API endpoint. * @param {string} str String to tokenize. + * @param {function} [resolve] Promise resolve function. * @returns {number[]} Array of token ids. */ -function getTextTokensFromServer(endpoint, str) { +function getTextTokensFromServer(endpoint, str, resolve) { + const isAsync = typeof resolve === 'function'; let ids = []; jQuery.ajax({ - async: false, + async: isAsync, type: 'POST', url: endpoint, data: JSON.stringify({ text: str }), @@ -623,6 +791,8 @@ function getTextTokensFromServer(endpoint, str) { if (Array.isArray(data.chunks)) { Object.defineProperty(ids, 'chunks', { value: data.chunks }); } + + isAsync && resolve(ids); }, }); return ids; @@ -631,12 +801,14 @@ function getTextTokensFromServer(endpoint, str) { /** * Calls the AI provider's tokenize API to encode a string to tokens. * @param {string} str String to tokenize. + * @param {function} [resolve] Promise resolve function. * @returns {number[]} Array of token ids. */ -function getTextTokensFromTextgenAPI(str) { +function getTextTokensFromTextgenAPI(str, resolve) { + const isAsync = typeof resolve === 'function'; let ids = []; jQuery.ajax({ - async: false, + async: isAsync, type: 'POST', url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode, data: JSON.stringify(getTextgenAPITokenizationParams(str)), @@ -644,6 +816,7 @@ function getTextTokensFromTextgenAPI(str) { contentType: 'application/json', success: function (data) { ids = data.ids; + isAsync && resolve(ids); }, }); return ids; @@ -652,13 +825,15 @@ function getTextTokensFromTextgenAPI(str) { /** * Calls the AI provider's tokenize API to encode a string to tokens. * @param {string} str String to tokenize. + * @param {function} [resolve] Promise resolve function. * @returns {number[]} Array of token ids. */ -function getTextTokensFromKoboldAPI(str) { +function getTextTokensFromKoboldAPI(str, resolve) { + const isAsync = typeof resolve === 'function'; let ids = []; jQuery.ajax({ - async: false, + async: isAsync, type: 'POST', url: TOKENIZER_URLS[tokenizers.API_KOBOLD].encode, data: JSON.stringify({ @@ -669,6 +844,7 @@ function getTextTokensFromKoboldAPI(str) { contentType: 'application/json', success: function (data) { ids = data.ids; + isAsync && resolve(ids); }, }); @@ -679,13 +855,15 @@ function getTextTokensFromKoboldAPI(str) { * Calls the underlying tokenizer model to decode token ids to text. * @param {string} endpoint API endpoint. * @param {number[]} ids Array of token ids + * @param {function} [resolve] Promise resolve function. * @returns {({ text: string, chunks?: string[] })} Decoded token text as a single string and individual chunks (if available). */ -function decodeTextTokensFromServer(endpoint, ids) { +function decodeTextTokensFromServer(endpoint, ids, resolve) { + const isAsync = typeof resolve === 'function'; let text = ''; let chunks = []; jQuery.ajax({ - async: false, + async: isAsync, type: 'POST', url: endpoint, data: JSON.stringify({ ids: ids }), @@ -694,6 +872,7 @@ function decodeTextTokensFromServer(endpoint, ids) { success: function (data) { text = data.text; chunks = data.chunks; + isAsync && resolve({ text, chunks }); }, }); return { text, chunks };