#2085 Implement async token counting

This commit is contained in:
Cohee
2024-04-13 21:05:31 +03:00
parent ba397dd2a8
commit 1c4bad35b2
5 changed files with 229 additions and 26 deletions

View File

@ -256,11 +256,93 @@ function callTokenizer(type, str) {
}
}
/**
 * Calls the underlying tokenizer model to get the token count for a string.
 *
 * The count helpers used here are callback-based, so the work is wrapped in a
 * Promise whose `resolve` is handed to the selected helper.
 * @param {number} type Tokenizer type.
 * @param {string} str String to tokenize.
 * @returns {Promise<number>} Token count.
 */
function callTokenizerAsync(type, str) {
    return new Promise(resolve => {
        if (type === tokenizers.NONE) {
            return resolve(guesstimate(str));
        }

        switch (type) {
            case tokenizers.API_CURRENT:
                // Resolve with the nested promise so its outcome is adopted as-is.
                // A bare `.then(resolve)` would drop a rejection and leave this
                // promise pending forever.
                return resolve(callTokenizerAsync(currentRemoteTokenizerAPI(), str));
            case tokenizers.API_KOBOLD:
                return countTokensFromKoboldAPI(str, resolve);
            case tokenizers.API_TEXTGENERATIONWEBUI:
                return countTokensFromTextgenAPI(str, resolve);
            default: {
                const endpointUrl = TOKENIZER_URLS[type]?.count;
                if (!endpointUrl) {
                    // No count endpoint registered for this type: use the failure fallback.
                    console.warn('Unknown tokenizer type', type);
                    return resolve(apiFailureTokenCount(str));
                }
                return countTokensFromServer(endpointUrl, str, resolve);
            }
        }
    });
}
/**
 * Asynchronously counts the tokens in a string with the currently selected tokenizer.
 *
 * Results are memoized in the token cache object under a key built from the
 * tokenizer type, the string hash and the padding.
 * @param {string} str String to tokenize
 * @param {number | undefined} padding Optional padding tokens. Defaults to 0.
 * @returns {Promise<number>} Token count (0 for empty or non-string input, or when the count is NaN).
 */
export async function getTokenCountAsync(str, padding = undefined) {
    const isEmptyInput = typeof str !== 'string' || !str?.length;
    if (isEmptyInput) {
        return 0;
    }

    let activeTokenizer = power_user.tokenizer;

    if (main_api === 'openai') {
        if (padding !== power_user.token_padding) {
            // For extensions and WI
            return counterWrapperOpenAIAsync(str);
        }
        // For main "shadow" prompt building
        activeTokenizer = tokenizers.NONE;
    }

    if (activeTokenizer === tokenizers.BEST_MATCH) {
        activeTokenizer = getTokenizerBestMatch(main_api);
    }

    const extraTokens = padding === undefined ? 0 : padding;

    const tokenCache = getTokenCacheObject();
    const key = `${activeTokenizer}-${getStringHash(str)}+${extraTokens}`;
    const memoized = tokenCache[key];
    if (typeof memoized === 'number') {
        return memoized;
    }

    const rawCount = await callTokenizerAsync(activeTokenizer, str);
    const total = rawCount + extraTokens;
    if (isNaN(total)) {
        // Unusable result: report zero and leave the cache untouched.
        console.warn('Token count calculation returned NaN');
        return 0;
    }

    tokenCache[key] = total;
    return total;
}
/**
* Gets the token count for a string using the current model tokenizer.
* @param {string} str String to tokenize
* @param {number | undefined} padding Optional padding tokens. Defaults to 0.
* @returns {number} Token count.
* @deprecated Use getTokenCountAsync instead.
*/
export function getTokenCount(str, padding = undefined) {
if (typeof str !== 'string' || !str?.length) {
@ -310,12 +392,23 @@ export function getTokenCount(str, padding = undefined) {
* Gets the token count for a string using the OpenAI tokenizer.
* @param {string} text Text to tokenize.
* @returns {number} Token count.
* @deprecated Use counterWrapperOpenAIAsync instead.
*/
function counterWrapperOpenAI(text) {
    // Present the raw text as a single system message and request a "full" count.
    return countTokensOpenAI({ role: 'system', content: text }, true);
}
/**
 * Asynchronously counts the tokens of a plain string with the OpenAI tokenizer
 * by wrapping it in a single system message.
 * @param {string} text Text to tokenize.
 * @returns {Promise<number>} Token count.
 */
function counterWrapperOpenAIAsync(text) {
    // The second argument requests a "full" count from the underlying counter.
    return countTokensOpenAIAsync({ role: 'system', content: text }, true);
}
export function getTokenizerModel() {
// OpenAI models always provide their own tokenizer
if (oai_settings.chat_completion_source == chat_completion_sources.OPENAI) {
@ -410,6 +503,7 @@ export function getTokenizerModel() {
/**
* @param {any[] | Object} messages
* @deprecated Use countTokensOpenAIAsync instead.
*/
export function countTokensOpenAI(messages, full = false) {
const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
@ -466,6 +560,66 @@ export function countTokensOpenAI(messages, full = false) {
return token_count;
}
/**
 * Returns the token count for one or more messages using the OpenAI tokenizer endpoints.
 *
 * Each message is counted with its own request and the per-message result is
 * stored in the token cache object under a `<model>-<hash>` key.
 * @param {object[]|object} messages A single message or an array of messages.
 * @param {boolean} full When falsy, 2 is subtracted from the final total.
 * @returns {Promise<number>} Token count.
 */
export async function countTokensOpenAIAsync(messages, full = false) {
    const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
    const shouldTokenizeGoogle = oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE && oai_settings.use_google_tokenizer;

    // Read the model once. Re-reading it after each await could pair a count
    // fetched from one model's endpoint with another model's cache key if the
    // setting changes while requests are in flight.
    const model = getTokenizerModel();

    let tokenizerEndpoint = '';
    if (shouldTokenizeAI21) {
        tokenizerEndpoint = '/api/tokenizers/ai21/count';
    } else if (shouldTokenizeGoogle) {
        tokenizerEndpoint = `/api/tokenizers/google/count?model=${model}`;
    } else {
        tokenizerEndpoint = `/api/tokenizers/openai/count?model=${model}`;
    }

    const cacheObject = getTokenCacheObject();
    const messageList = Array.isArray(messages) ? messages : [messages];

    // These tokenizers always use the "full" count. The override only applies
    // when there is at least one message, as it did when it was set inside the loop.
    const forcesFullCount = model === 'claude' || shouldTokenizeAI21 || shouldTokenizeGoogle;
    const isFullCount = full || (forcesFullCount && messageList.length > 0);

    // NOTE(review): the -1 start value and the -2 adjustment below are carried over
    // unchanged; their rationale is not visible here - confirm before altering.
    let tokenCount = -1;

    for (const message of messageList) {
        const hash = getStringHash(JSON.stringify(message));
        const cacheKey = `${model}-${hash}`;
        const cachedCount = cacheObject[cacheKey];

        if (typeof cachedCount === 'number') {
            tokenCount += cachedCount;
            continue;
        }

        const data = await jQuery.ajax({
            type: 'POST',
            url: tokenizerEndpoint,
            data: JSON.stringify([message]),
            dataType: 'json',
            contentType: 'application/json',
        });

        const count = Number(data.token_count);
        tokenCount += count;
        cacheObject[cacheKey] = count;
    }

    if (!isFullCount) {
        tokenCount -= 2;
    }

    return tokenCount;
}
/**
* Gets the token cache object for the current chat.
* @returns {Object} Token cache object for the current chat.
@ -495,13 +649,15 @@ function getTokenCacheObject() {
* Count tokens using the server API.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number} Token count.
*/
function countTokensFromServer(endpoint, str) {
function countTokensFromServer(endpoint, str, resolve) {
const isAsync = typeof resolve === 'function';
let tokenCount = 0;
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: endpoint,
data: JSON.stringify({ text: str }),
@ -513,6 +669,8 @@ function countTokensFromServer(endpoint, str) {
} else {
tokenCount = apiFailureTokenCount(str);
}
isAsync && resolve(tokenCount);
},
});
@ -522,13 +680,15 @@ function countTokensFromServer(endpoint, str) {
/**
* Count tokens using the AI provider's API.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number} Token count.
*/
function countTokensFromKoboldAPI(str) {
function countTokensFromKoboldAPI(str, resolve) {
const isAsync = typeof resolve === 'function';
let tokenCount = 0;
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count,
data: JSON.stringify({
@ -543,6 +703,8 @@ function countTokensFromKoboldAPI(str) {
} else {
tokenCount = apiFailureTokenCount(str);
}
isAsync && resolve(tokenCount);
},
});
@ -561,13 +723,15 @@ function getTextgenAPITokenizationParams(str) {
/**
* Count tokens using the AI provider's API.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number} Token count.
*/
function countTokensFromTextgenAPI(str) {
function countTokensFromTextgenAPI(str, resolve) {
const isAsync = typeof resolve === 'function';
let tokenCount = 0;
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count,
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
@ -579,6 +743,8 @@ function countTokensFromTextgenAPI(str) {
} else {
tokenCount = apiFailureTokenCount(str);
}
isAsync && resolve(tokenCount);
},
});
@ -605,12 +771,14 @@ function apiFailureTokenCount(str) {
* Calls the underlying tokenizer model to encode a string to tokens.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number[]} Array of token ids.
*/
function getTextTokensFromServer(endpoint, str) {
function getTextTokensFromServer(endpoint, str, resolve) {
const isAsync = typeof resolve === 'function';
let ids = [];
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: endpoint,
data: JSON.stringify({ text: str }),
@ -623,6 +791,8 @@ function getTextTokensFromServer(endpoint, str) {
if (Array.isArray(data.chunks)) {
Object.defineProperty(ids, 'chunks', { value: data.chunks });
}
isAsync && resolve(ids);
},
});
return ids;
@ -631,12 +801,14 @@ function getTextTokensFromServer(endpoint, str) {
/**
* Calls the AI provider's tokenize API to encode a string to tokens.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number[]} Array of token ids.
*/
function getTextTokensFromTextgenAPI(str) {
function getTextTokensFromTextgenAPI(str, resolve) {
const isAsync = typeof resolve === 'function';
let ids = [];
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode,
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
@ -644,6 +816,7 @@ function getTextTokensFromTextgenAPI(str) {
contentType: 'application/json',
success: function (data) {
ids = data.ids;
isAsync && resolve(ids);
},
});
return ids;
@ -652,13 +825,15 @@ function getTextTokensFromTextgenAPI(str) {
/**
* Calls the AI provider's tokenize API to encode a string to tokens.
* @param {string} str String to tokenize.
* @param {function} [resolve] Promise resolve function.
* @returns {number[]} Array of token ids.
*/
function getTextTokensFromKoboldAPI(str) {
function getTextTokensFromKoboldAPI(str, resolve) {
const isAsync = typeof resolve === 'function';
let ids = [];
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: TOKENIZER_URLS[tokenizers.API_KOBOLD].encode,
data: JSON.stringify({
@ -669,6 +844,7 @@ function getTextTokensFromKoboldAPI(str) {
contentType: 'application/json',
success: function (data) {
ids = data.ids;
isAsync && resolve(ids);
},
});
@ -679,13 +855,15 @@ function getTextTokensFromKoboldAPI(str) {
* Calls the underlying tokenizer model to decode token ids to text.
* @param {string} endpoint API endpoint.
* @param {number[]} ids Array of token ids
* @param {function} [resolve] Promise resolve function.
* @returns {({ text: string, chunks?: string[] })} Decoded token text as a single string and individual chunks (if available).
*/
function decodeTextTokensFromServer(endpoint, ids) {
function decodeTextTokensFromServer(endpoint, ids, resolve) {
const isAsync = typeof resolve === 'function';
let text = '';
let chunks = [];
jQuery.ajax({
async: false,
async: isAsync,
type: 'POST',
url: endpoint,
data: JSON.stringify({ ids: ids }),
@ -694,6 +872,7 @@ function decodeTextTokensFromServer(endpoint, ids) {
success: function (data) {
text = data.text;
chunks = data.chunks;
isAsync && resolve({ text, chunks });
},
});
return { text, chunks };