Add padding once in getTokenCount

This means we don't have to pass the "padding" parameter into every
function so they can add the padding themselves--we can do it in just
one place instead.
This commit is contained in:
valadaptive 2023-12-09 20:53:16 -05:00
parent 2f2cd197cc
commit 014416546c
1 changed files with 16 additions and 20 deletions

View File

@ -199,24 +199,23 @@ export function getTokenizerBestMatch(forApi) {
* Calls the underlying tokenizer model to the token count for a string.
* @param {number} type Tokenizer type.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count.
*/
function callTokenizer(type, str, padding) {
if (type === tokenizers.NONE) return guesstimate(str) + padding;
function callTokenizer(type, str) {
if (type === tokenizers.NONE) return guesstimate(str);
switch (type) {
case tokenizers.API_KOBOLD:
return countTokensFromKoboldAPI(str, padding);
return countTokensFromKoboldAPI(str);
case tokenizers.API_TEXTGENERATIONWEBUI:
return countTokensFromTextgenAPI(str, padding);
return countTokensFromTextgenAPI(str);
default: {
const endpointUrl = TOKENIZER_URLS[type]?.count;
if (!endpointUrl) {
console.warn('Unknown tokenizer type', type);
return callTokenizer(tokenizers.NONE, str, padding);
return callTokenizer(tokenizers.NONE, str);
}
return countTokensFromServer(endpointUrl, str, padding);
return countTokensFromServer(endpointUrl, str);
}
}
}
@ -260,7 +259,7 @@ export function getTokenCount(str, padding = undefined) {
return cacheObject[cacheKey];
}
const result = callTokenizer(tokenizerType, str, padding);
const result = callTokenizer(tokenizerType, str) + padding;
if (isNaN(result)) {
console.warn('Token count calculation returned NaN');
@ -436,10 +435,9 @@ function getTokenCacheObject() {
* Count tokens using the server API.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count with padding.
* @returns {number} Token count.
*/
function countTokensFromServer(endpoint, str, padding) {
function countTokensFromServer(endpoint, str) {
let tokenCount = 0;
jQuery.ajax({
@ -458,16 +456,15 @@ function countTokensFromServer(endpoint, str, padding) {
},
});
return tokenCount + padding;
return tokenCount;
}
/**
* Count tokens using the AI provider's API.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count with padding.
* @returns {number} Token count.
*/
function countTokensFromKoboldAPI(str, padding) {
function countTokensFromKoboldAPI(str) {
let tokenCount = 0;
jQuery.ajax({
@ -489,7 +486,7 @@ function countTokensFromKoboldAPI(str, padding) {
},
});
return tokenCount + padding;
return tokenCount;
}
function getTextgenAPITokenizationParams(str) {
@ -506,10 +503,9 @@ function getTextgenAPITokenizationParams(str) {
/**
* Count tokens using the AI provider's API.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count with padding.
* @returns {number} Token count.
*/
function countTokensFromTextgenAPI(str, padding) {
function countTokensFromTextgenAPI(str) {
let tokenCount = 0;
jQuery.ajax({
@ -528,7 +524,7 @@ function countTokensFromTextgenAPI(str, padding) {
},
});
return tokenCount + padding;
return tokenCount;
}
function apiFailureTokenCount(str) {