mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
#2085 Implement async token counting
This commit is contained in:
@ -256,11 +256,93 @@ function callTokenizer(type, str) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Calls the underlying tokenizer model to get the token count for a string.
 * @param {number} type Tokenizer type.
 * @param {string} str String to tokenize.
 * @returns {Promise<number>} Token count.
 */
function callTokenizerAsync(type, str) {
    return new Promise(resolve => {
        if (type === tokenizers.NONE) {
            // No tokenizer configured: use the character-based estimate.
            return resolve(guesstimate(str));
        }

        switch (type) {
            case tokenizers.API_CURRENT:
                // Resolve the current remote tokenizer API and recurse with the concrete type.
                return callTokenizerAsync(currentRemoteTokenizerAPI(), str).then(resolve);
            case tokenizers.API_KOBOLD:
                return countTokensFromKoboldAPI(str, resolve);
            case tokenizers.API_TEXTGENERATIONWEBUI:
                return countTokensFromTextgenAPI(str, resolve);
            default: {
                const endpointUrl = TOKENIZER_URLS[type]?.count;
                if (!endpointUrl) {
                    console.warn('Unknown tokenizer type', type);
                    // Fall back to a failure estimate instead of rejecting the promise.
                    return resolve(apiFailureTokenCount(str));
                }
                return countTokensFromServer(endpointUrl, str, resolve);
            }
        }
    });
}
|
||||
|
||||
/**
 * Gets the token count for a string using the current model tokenizer.
 * @param {string} str String to tokenize
 * @param {number | undefined} padding Optional padding tokens. Defaults to 0.
 * @returns {Promise<number>} Token count.
 */
export async function getTokenCountAsync(str, padding = undefined) {
    // Nothing to count for non-strings or empty strings.
    if (typeof str !== 'string' || !str?.length) {
        return 0;
    }

    let tokenizer = power_user.tokenizer;

    if (main_api === 'openai') {
        if (padding !== power_user.token_padding) {
            // For extensions and WI
            return counterWrapperOpenAIAsync(str);
        }
        // For main "shadow" prompt building
        tokenizer = tokenizers.NONE;
    }

    if (tokenizer === tokenizers.BEST_MATCH) {
        tokenizer = getTokenizerBestMatch(main_api);
    }

    if (typeof padding === 'undefined') {
        padding = 0;
    }

    const cache = getTokenCacheObject();
    const cacheKey = `${tokenizer}-${getStringHash(str)}+${padding}`;
    const cached = cache[cacheKey];

    // Serve a previously computed count when available.
    if (typeof cached === 'number') {
        return cached;
    }

    const count = (await callTokenizerAsync(tokenizer, str)) + padding;

    if (isNaN(count)) {
        console.warn('Token count calculation returned NaN');
        return 0;
    }

    cache[cacheKey] = count;
    return count;
}
|
||||
|
||||
/**
|
||||
* Gets the token count for a string using the current model tokenizer.
|
||||
* @param {string} str String to tokenize
|
||||
* @param {number | undefined} padding Optional padding tokens. Defaults to 0.
|
||||
* @returns {number} Token count.
|
||||
* @deprecated Use getTokenCountAsync instead.
|
||||
*/
|
||||
export function getTokenCount(str, padding = undefined) {
|
||||
if (typeof str !== 'string' || !str?.length) {
|
||||
@ -310,12 +392,23 @@ export function getTokenCount(str, padding = undefined) {
|
||||
/**
 * Gets the token count for a string using the OpenAI tokenizer.
 * @param {string} text Text to tokenize.
 * @returns {number} Token count.
 * @deprecated Use counterWrapperOpenAIAsync instead.
 */
function counterWrapperOpenAI(text) {
    // Wrap the text as a single system message for the chat-completion counter.
    return countTokensOpenAI({ role: 'system', content: text }, true);
}
|
||||
|
||||
/**
 * Gets the token count for a string using the OpenAI tokenizer.
 * @param {string} text Text to tokenize.
 * @returns {Promise<number>} Token count.
 */
function counterWrapperOpenAIAsync(text) {
    // Wrap the text as a single system message for the chat-completion counter.
    return countTokensOpenAIAsync({ role: 'system', content: text }, true);
}
|
||||
|
||||
export function getTokenizerModel() {
|
||||
// OpenAI models always provide their own tokenizer
|
||||
if (oai_settings.chat_completion_source == chat_completion_sources.OPENAI) {
|
||||
@ -410,6 +503,7 @@ export function getTokenizerModel() {
|
||||
|
||||
/**
|
||||
* @param {any[] | Object} messages
|
||||
* @deprecated Use countTokensOpenAIAsync instead.
|
||||
*/
|
||||
export function countTokensOpenAI(messages, full = false) {
|
||||
const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
|
||||
@ -466,6 +560,66 @@ export function countTokensOpenAI(messages, full = false) {
|
||||
return token_count;
|
||||
}
|
||||
|
||||
/**
 * Returns the token count for a message using the OpenAI tokenizer.
 * @param {object[]|object} messages Message object(s) to tokenize; a single object is wrapped in an array.
 * @param {boolean} full Whether to return the full count without the reply-priming adjustment.
 * @returns {Promise<number>} Token count.
 */
export async function countTokensOpenAIAsync(messages, full = false) {
    const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
    const shouldTokenizeGoogle = oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE && oai_settings.use_google_tokenizer;
    // getTokenizerModel() is loop-invariant; compute it once instead of per message.
    const model = getTokenizerModel();
    let tokenizerEndpoint = '';
    if (shouldTokenizeAI21) {
        tokenizerEndpoint = '/api/tokenizers/ai21/count';
    } else if (shouldTokenizeGoogle) {
        tokenizerEndpoint = `/api/tokenizers/google/count?model=${model}`;
    } else {
        tokenizerEndpoint = `/api/tokenizers/openai/count?model=${model}`;
    }
    const cacheObject = getTokenCacheObject();

    if (!Array.isArray(messages)) {
        messages = [messages];
    }

    // NOTE(review): starts at -1 to mirror the deprecated sync countTokensOpenAI;
    // combined with the `-2` adjustment below this reproduces its historical offsets.
    let token_count = -1;

    for (const message of messages) {
        if (model === 'claude' || shouldTokenizeAI21 || shouldTokenizeGoogle) {
            // These backends count whole messages; skip the reply-priming adjustment.
            full = true;
        }

        const hash = getStringHash(JSON.stringify(message));
        const cacheKey = `${model}-${hash}`;
        const cachedCount = cacheObject[cacheKey];

        if (typeof cachedCount === 'number') {
            token_count += cachedCount;
        }

        else {
            const data = await jQuery.ajax({
                async: true,
                type: 'POST',
                url: tokenizerEndpoint,
                data: JSON.stringify([message]),
                dataType: 'json',
                contentType: 'application/json',
            });

            token_count += Number(data.token_count);
            cacheObject[cacheKey] = Number(data.token_count);
        }
    }

    if (!full) token_count -= 2;

    return token_count;
}
|
||||
|
||||
/**
|
||||
* Gets the token cache object for the current chat.
|
||||
* @returns {Object} Token cache object for the current chat.
|
||||
@ -495,13 +649,15 @@ function getTokenCacheObject() {
|
||||
* Count tokens using the server API.
|
||||
* @param {string} endpoint API endpoint.
|
||||
* @param {string} str String to tokenize.
|
||||
 * @param {function} [resolve] Promise resolve function.
|
||||
* @returns {number} Token count.
|
||||
*/
|
||||
function countTokensFromServer(endpoint, str) {
|
||||
function countTokensFromServer(endpoint, str, resolve) {
|
||||
const isAsync = typeof resolve === 'function';
|
||||
let tokenCount = 0;
|
||||
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
async: isAsync,
|
||||
type: 'POST',
|
||||
url: endpoint,
|
||||
data: JSON.stringify({ text: str }),
|
||||
@ -513,6 +669,8 @@ function countTokensFromServer(endpoint, str) {
|
||||
} else {
|
||||
tokenCount = apiFailureTokenCount(str);
|
||||
}
|
||||
|
||||
isAsync && resolve(tokenCount);
|
||||
},
|
||||
});
|
||||
|
||||
@ -522,13 +680,15 @@ function countTokensFromServer(endpoint, str) {
|
||||
/**
|
||||
* Count tokens using the AI provider's API.
|
||||
* @param {string} str String to tokenize.
|
||||
* @param {function} [resolve] Promise resolve function.
|
||||
* @returns {number} Token count.
|
||||
*/
|
||||
function countTokensFromKoboldAPI(str) {
|
||||
function countTokensFromKoboldAPI(str, resolve) {
|
||||
const isAsync = typeof resolve === 'function';
|
||||
let tokenCount = 0;
|
||||
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
async: isAsync,
|
||||
type: 'POST',
|
||||
url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count,
|
||||
data: JSON.stringify({
|
||||
@ -543,6 +703,8 @@ function countTokensFromKoboldAPI(str) {
|
||||
} else {
|
||||
tokenCount = apiFailureTokenCount(str);
|
||||
}
|
||||
|
||||
isAsync && resolve(tokenCount);
|
||||
},
|
||||
});
|
||||
|
||||
@ -561,13 +723,15 @@ function getTextgenAPITokenizationParams(str) {
|
||||
/**
|
||||
* Count tokens using the AI provider's API.
|
||||
* @param {string} str String to tokenize.
|
||||
* @param {function} [resolve] Promise resolve function.
|
||||
* @returns {number} Token count.
|
||||
*/
|
||||
function countTokensFromTextgenAPI(str) {
|
||||
function countTokensFromTextgenAPI(str, resolve) {
|
||||
const isAsync = typeof resolve === 'function';
|
||||
let tokenCount = 0;
|
||||
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
async: isAsync,
|
||||
type: 'POST',
|
||||
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count,
|
||||
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
|
||||
@ -579,6 +743,8 @@ function countTokensFromTextgenAPI(str) {
|
||||
} else {
|
||||
tokenCount = apiFailureTokenCount(str);
|
||||
}
|
||||
|
||||
isAsync && resolve(tokenCount);
|
||||
},
|
||||
});
|
||||
|
||||
@ -605,12 +771,14 @@ function apiFailureTokenCount(str) {
|
||||
* Calls the underlying tokenizer model to encode a string to tokens.
|
||||
* @param {string} endpoint API endpoint.
|
||||
* @param {string} str String to tokenize.
|
||||
* @param {function} [resolve] Promise resolve function.
|
||||
* @returns {number[]} Array of token ids.
|
||||
*/
|
||||
function getTextTokensFromServer(endpoint, str) {
|
||||
function getTextTokensFromServer(endpoint, str, resolve) {
|
||||
const isAsync = typeof resolve === 'function';
|
||||
let ids = [];
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
async: isAsync,
|
||||
type: 'POST',
|
||||
url: endpoint,
|
||||
data: JSON.stringify({ text: str }),
|
||||
@ -623,6 +791,8 @@ function getTextTokensFromServer(endpoint, str) {
|
||||
if (Array.isArray(data.chunks)) {
|
||||
Object.defineProperty(ids, 'chunks', { value: data.chunks });
|
||||
}
|
||||
|
||||
isAsync && resolve(ids);
|
||||
},
|
||||
});
|
||||
return ids;
|
||||
@ -631,12 +801,14 @@ function getTextTokensFromServer(endpoint, str) {
|
||||
/**
|
||||
* Calls the AI provider's tokenize API to encode a string to tokens.
|
||||
* @param {string} str String to tokenize.
|
||||
* @param {function} [resolve] Promise resolve function.
|
||||
* @returns {number[]} Array of token ids.
|
||||
*/
|
||||
function getTextTokensFromTextgenAPI(str) {
|
||||
function getTextTokensFromTextgenAPI(str, resolve) {
|
||||
const isAsync = typeof resolve === 'function';
|
||||
let ids = [];
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
async: isAsync,
|
||||
type: 'POST',
|
||||
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode,
|
||||
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
|
||||
@ -644,6 +816,7 @@ function getTextTokensFromTextgenAPI(str) {
|
||||
contentType: 'application/json',
|
||||
success: function (data) {
|
||||
ids = data.ids;
|
||||
isAsync && resolve(ids);
|
||||
},
|
||||
});
|
||||
return ids;
|
||||
@ -652,13 +825,15 @@ function getTextTokensFromTextgenAPI(str) {
|
||||
/**
|
||||
* Calls the AI provider's tokenize API to encode a string to tokens.
|
||||
* @param {string} str String to tokenize.
|
||||
* @param {function} [resolve] Promise resolve function.
|
||||
* @returns {number[]} Array of token ids.
|
||||
*/
|
||||
function getTextTokensFromKoboldAPI(str) {
|
||||
function getTextTokensFromKoboldAPI(str, resolve) {
|
||||
const isAsync = typeof resolve === 'function';
|
||||
let ids = [];
|
||||
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
async: isAsync,
|
||||
type: 'POST',
|
||||
url: TOKENIZER_URLS[tokenizers.API_KOBOLD].encode,
|
||||
data: JSON.stringify({
|
||||
@ -669,6 +844,7 @@ function getTextTokensFromKoboldAPI(str) {
|
||||
contentType: 'application/json',
|
||||
success: function (data) {
|
||||
ids = data.ids;
|
||||
isAsync && resolve(ids);
|
||||
},
|
||||
});
|
||||
|
||||
@ -679,13 +855,15 @@ function getTextTokensFromKoboldAPI(str) {
|
||||
* Calls the underlying tokenizer model to decode token ids to text.
|
||||
* @param {string} endpoint API endpoint.
|
||||
* @param {number[]} ids Array of token ids
|
||||
* @param {function} [resolve] Promise resolve function.
|
||||
* @returns {({ text: string, chunks?: string[] })} Decoded token text as a single string and individual chunks (if available).
|
||||
*/
|
||||
function decodeTextTokensFromServer(endpoint, ids) {
|
||||
function decodeTextTokensFromServer(endpoint, ids, resolve) {
|
||||
const isAsync = typeof resolve === 'function';
|
||||
let text = '';
|
||||
let chunks = [];
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
async: isAsync,
|
||||
type: 'POST',
|
||||
url: endpoint,
|
||||
data: JSON.stringify({ ids: ids }),
|
||||
@ -694,6 +872,7 @@ function decodeTextTokensFromServer(endpoint, ids) {
|
||||
success: function (data) {
|
||||
text = data.text;
|
||||
chunks = data.chunks;
|
||||
isAsync && resolve({ text, chunks });
|
||||
},
|
||||
});
|
||||
return { text, chunks };
|
||||
|
Reference in New Issue
Block a user