#2085 Implement async token counting

This commit is contained in:
Cohee 2024-04-13 21:05:31 +03:00
parent ba397dd2a8
commit 1c4bad35b2
5 changed files with 229 additions and 26 deletions

View File

@ -82,6 +82,7 @@ import {
flushEphemeralStoppingStrings,
context_presets,
resetMovableStyles,
forceCharacterEditorTokenize,
} from './scripts/power-user.js';
import {
@ -5840,10 +5841,11 @@ function changeMainAPI() {
if (main_api == 'koboldhorde') {
getStatusHorde();
getHordeModels();
getHordeModels(true);
}
setupChatCompletionPromptManager(oai_settings);
forceCharacterEditorTokenize();
}
////////////////////////////////////////////////////

View File

@ -34,7 +34,7 @@ import {
} from './secrets.js';
import { debounce, delay, getStringHash, isValidUrl } from './utils.js';
import { chat_completion_sources, oai_settings } from './openai.js';
import { getTokenCount } from './tokenizers.js';
import { getTokenCountAsync } from './tokenizers.js';
import { textgen_types, textgenerationwebui_settings as textgen_settings, getTextGenServer } from './textgen-settings.js';
import Bowser from '../lib/bowser.min.js';
@ -51,6 +51,7 @@ var SelectedCharacterTab = document.getElementById('rm_button_selected_ch');
var connection_made = false;
var retry_delay = 500;
let counterNonce = Date.now();
const observerConfig = { childList: true, subtree: true };
const countTokensDebounced = debounce(RA_CountCharTokens, 1000);
@ -202,24 +203,32 @@ $('#rm_ch_create_block').on('input', function () { countTokensDebounced(); });
//when any input is made to the advanced editing popup textareas
$('#character_popup').on('input', function () { countTokensDebounced(); });
//function:
export function RA_CountCharTokens() {
export async function RA_CountCharTokens() {
counterNonce = Date.now();
const counterNonceLocal = counterNonce;
let total_tokens = 0;
let permanent_tokens = 0;
$('[data-token-counter]').each(function () {
const counter = $(this);
const tokenCounters = document.querySelectorAll('[data-token-counter]');
for (const tokenCounter of tokenCounters) {
if (counterNonceLocal !== counterNonce) {
return;
}
const counter = $(tokenCounter);
const input = $(document.getElementById(counter.data('token-counter')));
const isPermanent = counter.data('token-permanent') === true;
const value = String(input.val());
if (input.length === 0) {
counter.text('Invalid input reference');
return;
continue;
}
if (!value) {
input.data('last-value-hash', '');
counter.text(0);
return;
continue;
}
const valueHash = getStringHash(value);
@ -230,13 +239,18 @@ export function RA_CountCharTokens() {
} else {
// We substitute macro for existing characters, but not for the character being created
const valueToCount = menu_type === 'create' ? value : substituteParams(value);
const tokens = getTokenCount(valueToCount);
const tokens = await getTokenCountAsync(valueToCount);
if (counterNonceLocal !== counterNonce) {
return;
}
counter.text(tokens);
total_tokens += tokens;
permanent_tokens += isPermanent ? tokens : 0;
input.data('last-value-hash', valueHash);
}
});
}
// Warn if total tokens exceeds the limit of half the max context
const tokenLimit = Math.max(((main_api !== 'openai' ? max_context : oai_settings.openai_max_context) / 2), 1024);

View File

@ -42,7 +42,7 @@ import {
promptManagerDefaultPromptOrders,
} from './PromptManager.js';
import { getCustomStoppingStrings, persona_description_positions, power_user } from './power-user.js';
import { forceCharacterEditorTokenize, getCustomStoppingStrings, persona_description_positions, power_user } from './power-user.js';
import { SECRET_KEYS, secret_state, writeSecret } from './secrets.js';
import { getEventSourceStream } from './sse-stream.js';
@ -4429,6 +4429,7 @@ $(document).ready(async function () {
toggleChatCompletionForms();
saveSettingsDebounced();
reconnectOpenAi();
forceCharacterEditorTokenize();
eventSource.emit(event_types.CHATCOMPLETION_SOURCE_CHANGED, oai_settings.chat_completion_source);
});

View File

@ -2764,6 +2764,14 @@ export function getCustomStoppingStrings(limit = undefined) {
return strings;
}
/**
 * Invalidates the character editor token counters and re-triggers counting.
 * Used after tokenizer-affecting settings change (API source, tokenizer type).
 */
export function forceCharacterEditorTokenize() {
    // Clear the cached value hashes so the next count is not short-circuited.
    $('[data-token-counter]').each(function () {
        const inputElement = document.getElementById($(this).data('token-counter'));
        $(inputElement).data('last-value-hash', '');
    });
    // Fire the (debounced) token counters for both editor panels.
    $('#rm_ch_create_block').trigger('input');
    $('#character_popup').trigger('input');
}
$(document).ready(() => {
const adjustAutocompleteDebounced = debounce(() => {
$('.ui-autocomplete-input').each(function () {
@ -3175,8 +3183,7 @@ $(document).ready(() => {
saveSettingsDebounced();
// Trigger character editor re-tokenize
$('#rm_ch_create_block').trigger('input');
$('#character_popup').trigger('input');
forceCharacterEditorTokenize();
});
$('#send_on_enter').on('change', function () {

View File

@ -256,11 +256,93 @@ function callTokenizer(type, str) {
}
}
/**
 * Calls the underlying tokenizer model to get the token count for a string.
 * @param {number} type Tokenizer type.
 * @param {string} str String to tokenize.
 * @returns {Promise<number>} Token count.
 */
function callTokenizerAsync(type, str) {
    return new Promise((resolve) => {
        switch (type) {
            case tokenizers.NONE:
                // Local estimate — no server round-trip needed.
                resolve(guesstimate(str));
                break;
            case tokenizers.API_CURRENT:
                // Re-dispatch to whichever remote tokenizer API is currently active.
                callTokenizerAsync(currentRemoteTokenizerAPI(), str).then(resolve);
                break;
            case tokenizers.API_KOBOLD:
                countTokensFromKoboldAPI(str, resolve);
                break;
            case tokenizers.API_TEXTGENERATIONWEBUI:
                countTokensFromTextgenAPI(str, resolve);
                break;
            default: {
                const endpointUrl = TOKENIZER_URLS[type]?.count;
                if (!endpointUrl) {
                    console.warn('Unknown tokenizer type', type);
                    resolve(apiFailureTokenCount(str));
                    break;
                }
                countTokensFromServer(endpointUrl, str, resolve);
            }
        }
    });
}
/**
 * Gets the token count for a string using the current model tokenizer.
 * Results are memoized in the per-chat token cache, keyed by tokenizer type,
 * string hash, and padding.
 * @param {string} str String to tokenize
 * @param {number | undefined} padding Optional padding tokens. Defaults to 0.
 * @returns {Promise<number>} Token count.
 */
export async function getTokenCountAsync(str, padding = undefined) {
    if (typeof str !== 'string' || !str?.length) {
        return 0;
    }

    let tokenizerType = power_user.tokenizer;

    if (main_api === 'openai') {
        if (padding === power_user.token_padding) {
            // For main "shadow" prompt building
            tokenizerType = tokenizers.NONE;
        } else {
            // For extensions and WI
            return counterWrapperOpenAIAsync(str);
        }
    }

    if (tokenizerType === tokenizers.BEST_MATCH) {
        tokenizerType = getTokenizerBestMatch(main_api);
    }

    if (padding === undefined) {
        padding = 0;
    }

    const cacheObject = getTokenCacheObject();
    const hash = getStringHash(str);
    const cacheKey = `${tokenizerType}-${hash}+${padding}`;

    if (typeof cacheObject[cacheKey] === 'number') {
        return cacheObject[cacheKey];
    }

    const result = (await callTokenizerAsync(tokenizerType, str)) + padding;

    // Number.isNaN avoids the coercing global isNaN; `result` is always a number here.
    // NaN results are reported but deliberately not cached.
    if (Number.isNaN(result)) {
        console.warn('Token count calculation returned NaN');
        return 0;
    }

    cacheObject[cacheKey] = result;
    return result;
}
/**
* Gets the token count for a string using the current model tokenizer.
* @param {string} str String to tokenize
* @param {number | undefined} padding Optional padding tokens. Defaults to 0.
* @returns {number} Token count.
* @deprecated Use getTokenCountAsync instead.
*/
export function getTokenCount(str, padding = undefined) {
if (typeof str !== 'string' || !str?.length) {
@ -310,12 +392,23 @@ export function getTokenCount(str, padding = undefined) {
* Gets the token count for a string using the OpenAI tokenizer.
* @param {string} text Text to tokenize.
* @returns {number} Token count.
* @deprecated Use counterWrapperOpenAIAsync instead.
*/
function counterWrapperOpenAI(text) {
    // Wrap the raw string as a system message before counting.
    return countTokensOpenAI({ role: 'system', content: text }, true);
}
/**
 * Gets the token count for a string using the OpenAI tokenizer.
 * @param {string} text Text to tokenize.
 * @returns {Promise<number>} Token count.
 */
function counterWrapperOpenAIAsync(text) {
    // Wrap the raw string as a system message before counting.
    return countTokensOpenAIAsync({ role: 'system', content: text }, true);
}
export function getTokenizerModel() {
// OpenAI models always provide their own tokenizer
if (oai_settings.chat_completion_source == chat_completion_sources.OPENAI) {
@ -410,6 +503,7 @@ export function getTokenizerModel() {
/**
* @param {any[] | Object} messages
* @deprecated Use countTokensOpenAIAsync instead.
*/
export function countTokensOpenAI(messages, full = false) {
const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
@ -466,6 +560,66 @@ export function countTokensOpenAI(messages, full = false) {
return token_count;
}
/**
 * Returns the token count for a message using the OpenAI tokenizer.
 * Counts are cached per tokenizer model + message hash; cache misses hit the
 * server-side tokenizer endpoint one message at a time.
 * @param {object[]|object} messages
 * @param {boolean} full
 * @returns {Promise<number>} Token count.
 */
export async function countTokensOpenAIAsync(messages, full = false) {
// Providers whose counts can't come from the OpenAI tokenizer get dedicated endpoints.
const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
const shouldTokenizeGoogle = oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE && oai_settings.use_google_tokenizer;
let tokenizerEndpoint = '';
if (shouldTokenizeAI21) {
tokenizerEndpoint = '/api/tokenizers/ai21/count';
} else if (shouldTokenizeGoogle) {
tokenizerEndpoint = `/api/tokenizers/google/count?model=${getTokenizerModel()}`;
} else {
tokenizerEndpoint = `/api/tokenizers/openai/count?model=${getTokenizerModel()}`;
}
const cacheObject = getTokenCacheObject();
// Normalize a single message into a one-element array.
if (!Array.isArray(messages)) {
messages = [messages];
}
// NOTE(review): starts at -1, mirroring the sync countTokensOpenAI; combined
// with the -2 below when !full — confirm the intended overhead offset.
let token_count = -1;
for (const message of messages) {
const model = getTokenizerModel();
if (model === 'claude' || shouldTokenizeAI21 || shouldTokenizeGoogle) {
// Force `full` so the -2 discount below is skipped for these providers.
full = true;
}
// Cache key: tokenizer model + hash of the serialized message.
const hash = getStringHash(JSON.stringify(message));
const cacheKey = `${model}-${hash}`;
const cachedCount = cacheObject[cacheKey];
if (typeof cachedCount === 'number') {
token_count += cachedCount;
}
else {
// Cache miss: ask the server to count this single message, then cache it.
const data = await jQuery.ajax({
async: true,
type: 'POST',
url: tokenizerEndpoint,
data: JSON.stringify([message]),
dataType: 'json',
contentType: 'application/json',
});
token_count += Number(data.token_count);
cacheObject[cacheKey] = Number(data.token_count);
}
}
// Non-full counts drop 2 tokens — presumably chat-priming overhead; confirm.
if (!full) token_count -= 2;
return token_count;
}
/**
* Gets the token cache object for the current chat.
* @returns {Object} Token cache object for the current chat.
@ -495,13 +649,15 @@ function getTokenCacheObject() {
* Count tokens using the server API.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number} Token count.
*/
function countTokensFromServer(endpoint, str) {
function countTokensFromServer(endpoint, str, resolve) {
const isAsync = typeof resolve === 'function';
let tokenCount = 0;
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: endpoint,
data: JSON.stringify({ text: str }),
@ -513,6 +669,8 @@ function countTokensFromServer(endpoint, str) {
} else {
tokenCount = apiFailureTokenCount(str);
}
isAsync && resolve(tokenCount);
},
});
@ -522,13 +680,15 @@ function countTokensFromServer(endpoint, str) {
/**
* Count tokens using the AI provider's API.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number} Token count.
*/
function countTokensFromKoboldAPI(str) {
function countTokensFromKoboldAPI(str, resolve) {
const isAsync = typeof resolve === 'function';
let tokenCount = 0;
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count,
data: JSON.stringify({
@ -543,6 +703,8 @@ function countTokensFromKoboldAPI(str) {
} else {
tokenCount = apiFailureTokenCount(str);
}
isAsync && resolve(tokenCount);
},
});
@ -561,13 +723,15 @@ function getTextgenAPITokenizationParams(str) {
/**
* Count tokens using the AI provider's API.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number} Token count.
*/
function countTokensFromTextgenAPI(str) {
function countTokensFromTextgenAPI(str, resolve) {
const isAsync = typeof resolve === 'function';
let tokenCount = 0;
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count,
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
@ -579,6 +743,8 @@ function countTokensFromTextgenAPI(str) {
} else {
tokenCount = apiFailureTokenCount(str);
}
isAsync && resolve(tokenCount);
},
});
@ -605,12 +771,14 @@ function apiFailureTokenCount(str) {
* Calls the underlying tokenizer model to encode a string to tokens.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number[]} Array of token ids.
*/
function getTextTokensFromServer(endpoint, str) {
function getTextTokensFromServer(endpoint, str, resolve) {
const isAsync = typeof resolve === 'function';
let ids = [];
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: endpoint,
data: JSON.stringify({ text: str }),
@ -623,6 +791,8 @@ function getTextTokensFromServer(endpoint, str) {
if (Array.isArray(data.chunks)) {
Object.defineProperty(ids, 'chunks', { value: data.chunks });
}
isAsync && resolve(ids);
},
});
return ids;
@ -631,12 +801,14 @@ function getTextTokensFromServer(endpoint, str) {
/**
* Calls the AI provider's tokenize API to encode a string to tokens.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number[]} Array of token ids.
*/
function getTextTokensFromTextgenAPI(str) {
function getTextTokensFromTextgenAPI(str, resolve) {
const isAsync = typeof resolve === 'function';
let ids = [];
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode,
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
@ -644,6 +816,7 @@ function getTextTokensFromTextgenAPI(str) {
contentType: 'application/json',
success: function (data) {
ids = data.ids;
isAsync && resolve(ids);
},
});
return ids;
@ -652,13 +825,15 @@ function getTextTokensFromTextgenAPI(str) {
/**
* Calls the AI provider's tokenize API to encode a string to tokens.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number[]} Array of token ids.
*/
function getTextTokensFromKoboldAPI(str) {
function getTextTokensFromKoboldAPI(str, resolve) {
const isAsync = typeof resolve === 'function';
let ids = [];
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: TOKENIZER_URLS[tokenizers.API_KOBOLD].encode,
data: JSON.stringify({
@ -669,6 +844,7 @@ function getTextTokensFromKoboldAPI(str) {
contentType: 'application/json',
success: function (data) {
ids = data.ids;
isAsync && resolve(ids);
},
});
@ -679,13 +855,15 @@ function getTextTokensFromKoboldAPI(str) {
* Calls the underlying tokenizer model to decode token ids to text.
* @param {string} endpoint API endpoint.
* @param {number[]} ids Array of token ids
* @param {function} [resolve] Promise resolve function.
* @returns {({ text: string, chunks?: string[] })} Decoded token text as a single string and individual chunks (if available).
*/
function decodeTextTokensFromServer(endpoint, ids) {
function decodeTextTokensFromServer(endpoint, ids, resolve) {
const isAsync = typeof resolve === 'function';
let text = '';
let chunks = [];
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: endpoint,
data: JSON.stringify({ ids: ids }),
@ -694,6 +872,7 @@ function decodeTextTokensFromServer(endpoint, ids) {
success: function (data) {
text = data.text;
chunks = data.chunks;
isAsync && resolve({ text, chunks });
},
});
return { text, chunks };