Separate remote and server tokenization code paths

This lets us remove extraneous API params from paths where they aren't
needed.
This commit is contained in:
valadaptive
2023-12-09 20:08:48 -05:00
parent ddd73a204a
commit 18177c147d

View File

@ -173,7 +173,7 @@ function callTokenizer(type, str, padding) {
case tokenizers.YI: case tokenizers.YI:
return countTokensFromServer('/api/tokenizers/yi/encode', str, padding); return countTokensFromServer('/api/tokenizers/yi/encode', str, padding);
case tokenizers.API: case tokenizers.API:
return countTokensFromServer('/api/tokenizers/remote/encode', str, padding); return countTokensFromRemoteAPI('/api/tokenizers/remote/encode', str, padding);
default: default:
console.warn('Unknown tokenizer type', type); console.warn('Unknown tokenizer type', type);
return callTokenizer(tokenizers.NONE, str, padding); return callTokenizer(tokenizers.NONE, str, padding);
@ -392,6 +392,12 @@ function getTokenCacheObject() {
} }
function getServerTokenizationParams(str) { function getServerTokenizationParams(str) {
return {
text: str,
};
}
function getRemoteAPITokenizationParams(str) {
return { return {
text: str, text: str,
main_api, main_api,
@ -404,7 +410,7 @@ function getServerTokenizationParams(str) {
} }
/** /**
* Counts token using the server API. * Count tokens using the server API.
* @param {string} endpoint API endpoint. * @param {string} endpoint API endpoint.
* @param {string} str String to tokenize. * @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens. * @param {number} padding Number of padding tokens.
@ -424,18 +430,7 @@ function countTokensFromServer(endpoint, str, padding) {
if (typeof data.count === 'number') { if (typeof data.count === 'number') {
tokenCount = data.count; tokenCount = data.count;
} else { } else {
tokenCount = guesstimate(str); tokenCount = apiFailureTokenCount(str);
console.error('Error counting tokens');
if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
toastr.warning(
'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.',
'Error counting tokens',
{ timeOut: 10000, preventDuplicates: true },
);
sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
}
} }
}, },
}); });
@ -443,6 +438,51 @@ function countTokensFromServer(endpoint, str, padding) {
return tokenCount + padding; return tokenCount + padding;
} }
/**
 * Count tokens using the AI provider's API.
 *
 * Sends a blocking POST request to the given endpoint with the parameters
 * built by getRemoteAPITokenizationParams. If the request fails, or the
 * response carries no numeric `count`, the result of apiFailureTokenCount
 * is used instead so the caller never receives a bare zero for non-empty text.
 * @param {string} endpoint API endpoint.
 * @param {string} str String to tokenize.
 * @param {number} padding Number of padding tokens.
 * @returns {number} Token count with padding.
 */
function countTokensFromRemoteAPI(endpoint, str, padding) {
    let tokenCount = 0;

    jQuery.ajax({
        // Synchronous on purpose: the count is returned directly to the caller.
        async: false,
        type: 'POST',
        url: endpoint,
        data: JSON.stringify(getRemoteAPITokenizationParams(str)),
        dataType: 'json',
        contentType: 'application/json',
        success: function (data) {
            // A JSON body of `null` is a valid response, so check for it before reading `count`.
            if (data && typeof data.count === 'number') {
                tokenCount = data.count;
            } else {
                tokenCount = apiFailureTokenCount(str);
            }
        },
        error: function () {
            // HTTP or parse failure: without this handler the count would silently stay at 0.
            tokenCount = apiFailureTokenCount(str);
        },
    });

    return tokenCount + padding;
}
/**
 * Fallback used when the tokenization endpoint gives no usable count.
 *
 * Logs an error, shows a warning toast unless sessionStorage already holds
 * the TOKENIZER_WARNING_KEY flag (and then sets that flag), and returns the
 * value produced by guesstimate for the string.
 * @param {string} str String whose token count is being estimated.
 * @returns {number} Result of guesstimate(str).
 */
function apiFailureTokenCount(str) {
    console.error('Error counting tokens');

    const warningAlreadyShown = sessionStorage.getItem(TOKENIZER_WARNING_KEY);
    if (!warningAlreadyShown) {
        const toastTitle = 'Error counting tokens';
        const toastOptions = { timeOut: 10000, preventDuplicates: true };
        toastr.warning(
            'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.',
            toastTitle,
            toastOptions,
        );
        // Remember that the warning was shown so it is not repeated while the flag is set.
        sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
    }

    const estimate = guesstimate(str);
    return estimate;
}
/** /**
* Calls the underlying tokenizer model to encode a string to tokens. * Calls the underlying tokenizer model to encode a string to tokens.
* @param {string} endpoint API endpoint. * @param {string} endpoint API endpoint.
@ -475,6 +515,29 @@ function getTextTokensFromServer(endpoint, str, model = '') {
return ids; return ids;
} }
/**
 * Calls the AI provider's tokenize API to encode a string to tokens.
 *
 * Sends a blocking POST request to the given endpoint with the parameters
 * built by getRemoteAPITokenizationParams. An empty array is returned when
 * the request does not succeed or the response has no `ids` array.
 * @param {string} endpoint API endpoint.
 * @param {string} str String to tokenize.
 * @param {string} model Tokenizer model. Not used by this function; it mirrors
 * the parameter list of getTextTokensFromServer.
 * @returns {number[]} Array of token ids.
 */
function getTextTokensFromRemoteAPI(endpoint, str, model = '') {
    let ids = [];

    jQuery.ajax({
        // Synchronous on purpose: the ids are returned directly to the caller.
        async: false,
        type: 'POST',
        url: endpoint,
        data: JSON.stringify(getRemoteAPITokenizationParams(str)),
        dataType: 'json',
        contentType: 'application/json',
        success: function (data) {
            // Only accept a real array, so a response without `ids` cannot
            // replace the empty-array default with undefined.
            if (data && Array.isArray(data.ids)) {
                ids = data.ids;
            }
        },
    });

    return ids;
}
/** /**
* Calls the underlying tokenizer model to decode token ids to text. * Calls the underlying tokenizer model to decode token ids to text.
* @param {string} endpoint API endpoint. * @param {string} endpoint API endpoint.
@ -525,7 +588,7 @@ export function getTextTokens(tokenizerType, str) {
return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model); return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model);
} }
case tokenizers.API: case tokenizers.API:
return getTextTokensFromServer('/api/tokenizers/remote/encode', str); return getTextTokensFromRemoteAPI('/api/tokenizers/remote/encode', str);
default: default:
console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType); console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType);
return []; return [];