Separate remote and server tokenization code paths

This lets us remove extraneous API params from paths where they aren't
needed.
This commit is contained in:
valadaptive
2023-12-09 20:08:48 -05:00
parent ddd73a204a
commit 18177c147d

View File

@ -173,7 +173,7 @@ function callTokenizer(type, str, padding) {
case tokenizers.YI:
return countTokensFromServer('/api/tokenizers/yi/encode', str, padding);
case tokenizers.API:
return countTokensFromServer('/api/tokenizers/remote/encode', str, padding);
return countTokensFromRemoteAPI('/api/tokenizers/remote/encode', str, padding);
default:
console.warn('Unknown tokenizer type', type);
return callTokenizer(tokenizers.NONE, str, padding);
@ -392,6 +392,12 @@ function getTokenCacheObject() {
}
function getServerTokenizationParams(str) {
return {
text: str,
};
}
function getRemoteAPITokenizationParams(str) {
return {
text: str,
main_api,
@ -404,7 +410,7 @@ function getServerTokenizationParams(str) {
}
/**
* Counts token using the server API.
* Count tokens using the server API.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
@ -424,18 +430,7 @@ function countTokensFromServer(endpoint, str, padding) {
if (typeof data.count === 'number') {
tokenCount = data.count;
} else {
tokenCount = guesstimate(str);
console.error('Error counting tokens');
if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
toastr.warning(
'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.',
'Error counting tokens',
{ timeOut: 10000, preventDuplicates: true },
);
sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
}
tokenCount = apiFailureTokenCount(str);
}
},
});
@ -443,6 +438,51 @@ function countTokensFromServer(endpoint, str, padding) {
return tokenCount + padding;
}
/**
* Count tokens using the AI provider's API.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count with padding.
*/
function countTokensFromRemoteAPI(endpoint, str, padding) {
let tokenCount = 0;
jQuery.ajax({
async: false,
type: 'POST',
url: endpoint,
data: JSON.stringify(getRemoteAPITokenizationParams(str)),
dataType: 'json',
contentType: 'application/json',
success: function (data) {
if (typeof data.count === 'number') {
tokenCount = data.count;
} else {
tokenCount = apiFailureTokenCount(str);
}
},
});
return tokenCount + padding;
}
function apiFailureTokenCount(str) {
console.error('Error counting tokens');
if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
toastr.warning(
'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.',
'Error counting tokens',
{ timeOut: 10000, preventDuplicates: true },
);
sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
}
return guesstimate(str);
}
/**
* Calls the underlying tokenizer model to encode a string to tokens.
* @param {string} endpoint API endpoint.
@ -475,6 +515,29 @@ function getTextTokensFromServer(endpoint, str, model = '') {
return ids;
}
/**
* Calls the AI provider's tokenize API to encode a string to tokens.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {string} model Tokenizer model.
* @returns {number[]} Array of token ids.
*/
function getTextTokensFromRemoteAPI(endpoint, str, model = '') {
let ids = [];
jQuery.ajax({
async: false,
type: 'POST',
url: endpoint,
data: JSON.stringify(getRemoteAPITokenizationParams(str)),
dataType: 'json',
contentType: 'application/json',
success: function (data) {
ids = data.ids;
},
});
return ids;
}
/**
* Calls the underlying tokenizer model to decode token ids to text.
* @param {string} endpoint API endpoint.
@ -525,7 +588,7 @@ export function getTextTokens(tokenizerType, str) {
return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model);
}
case tokenizers.API:
return getTextTokensFromServer('/api/tokenizers/remote/encode', str);
return getTextTokensFromRemoteAPI('/api/tokenizers/remote/encode', str);
default:
console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType);
return [];