Add padding once in getTokenCount

This means we don't have to pass the "padding" parameter into every
function so they can add the padding themselves--we can do it in just
one place instead.
This commit is contained in:
valadaptive
2023-12-09 20:53:16 -05:00
parent 2f2cd197cc
commit 014416546c

View File

@ -199,24 +199,23 @@ export function getTokenizerBestMatch(forApi) {
* Calls the underlying tokenizer model to the token count for a string. * Calls the underlying tokenizer model to the token count for a string.
* @param {number} type Tokenizer type. * @param {number} type Tokenizer type.
* @param {string} str String to tokenize. * @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count. * @returns {number} Token count.
*/ */
function callTokenizer(type, str, padding) { function callTokenizer(type, str) {
if (type === tokenizers.NONE) return guesstimate(str) + padding; if (type === tokenizers.NONE) return guesstimate(str);
switch (type) { switch (type) {
case tokenizers.API_KOBOLD: case tokenizers.API_KOBOLD:
return countTokensFromKoboldAPI(str, padding); return countTokensFromKoboldAPI(str);
case tokenizers.API_TEXTGENERATIONWEBUI: case tokenizers.API_TEXTGENERATIONWEBUI:
return countTokensFromTextgenAPI(str, padding); return countTokensFromTextgenAPI(str);
default: { default: {
const endpointUrl = TOKENIZER_URLS[type]?.count; const endpointUrl = TOKENIZER_URLS[type]?.count;
if (!endpointUrl) { if (!endpointUrl) {
console.warn('Unknown tokenizer type', type); console.warn('Unknown tokenizer type', type);
return callTokenizer(tokenizers.NONE, str, padding); return callTokenizer(tokenizers.NONE, str);
} }
return countTokensFromServer(endpointUrl, str, padding); return countTokensFromServer(endpointUrl, str);
} }
} }
} }
@ -260,7 +259,7 @@ export function getTokenCount(str, padding = undefined) {
return cacheObject[cacheKey]; return cacheObject[cacheKey];
} }
const result = callTokenizer(tokenizerType, str, padding); const result = callTokenizer(tokenizerType, str) + padding;
if (isNaN(result)) { if (isNaN(result)) {
console.warn('Token count calculation returned NaN'); console.warn('Token count calculation returned NaN');
@ -436,10 +435,9 @@ function getTokenCacheObject() {
* Count tokens using the server API. * Count tokens using the server API.
* @param {string} endpoint API endpoint. * @param {string} endpoint API endpoint.
* @param {string} str String to tokenize. * @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens. * @returns {number} Token count.
* @returns {number} Token count with padding.
*/ */
function countTokensFromServer(endpoint, str, padding) { function countTokensFromServer(endpoint, str) {
let tokenCount = 0; let tokenCount = 0;
jQuery.ajax({ jQuery.ajax({
@ -458,16 +456,15 @@ function countTokensFromServer(endpoint, str, padding) {
}, },
}); });
return tokenCount + padding; return tokenCount;
} }
/** /**
* Count tokens using the AI provider's API. * Count tokens using the AI provider's API.
* @param {string} str String to tokenize. * @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens. * @returns {number} Token count.
* @returns {number} Token count with padding.
*/ */
function countTokensFromKoboldAPI(str, padding) { function countTokensFromKoboldAPI(str) {
let tokenCount = 0; let tokenCount = 0;
jQuery.ajax({ jQuery.ajax({
@ -489,7 +486,7 @@ function countTokensFromKoboldAPI(str, padding) {
}, },
}); });
return tokenCount + padding; return tokenCount;
} }
function getTextgenAPITokenizationParams(str) { function getTextgenAPITokenizationParams(str) {
@ -506,10 +503,9 @@ function getTextgenAPITokenizationParams(str) {
/** /**
* Count tokens using the AI provider's API. * Count tokens using the AI provider's API.
* @param {string} str String to tokenize. * @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens. * @returns {number} Token count.
* @returns {number} Token count with padding.
*/ */
function countTokensFromTextgenAPI(str, padding) { function countTokensFromTextgenAPI(str) {
let tokenCount = 0; let tokenCount = 0;
jQuery.ajax({ jQuery.ajax({
@ -528,7 +524,7 @@ function countTokensFromTextgenAPI(str, padding) {
}, },
}); });
return tokenCount + padding; return tokenCount;
} }
function apiFailureTokenCount(str) { function apiFailureTokenCount(str) {