Separate textgen and Kobold tokenization APIs

They function differently and have different logic and API parameters,
so it makes sense to treat them as two separate APIs. Kobold's API
doesn't return token ids, so it can only be used to count tokens.
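
For illustration, here is roughly what the two request payloads look like
after the split. Field names are taken from the new param builders in the
diff below; the values shown are placeholders, not real defaults:

    // Kobold: counting only.
    const koboldParams = {
        text: 'Hello world',
        main_api: 'kobold',
        url: 'http://127.0.0.1:5001',  // whatever getAPIServerUrl() returns
    };

    // Text generation WebUI: counting and encoding to token ids.
    const textgenParams = {
        text: 'Hello world',
        main_api: 'textgenerationwebui',
        api_type: 'ooba',              // textgen_settings.type (placeholder value)
        url: 'http://127.0.0.1:5000',  // whatever getAPIServerUrl() returns
        legacy_api: false,             // legacy-mode flag, never set for Mancer
    };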

There's still a lot of duplicate code, which I will clean up in the
following commits.
valadaptive
2023-12-09 20:20:53 -05:00
parent 18177c147d
commit 7486ab3886


@@ -18,9 +18,10 @@ export const tokenizers = {
     LLAMA: 3,
     NERD: 4,
     NERD2: 5,
-    API: 6,
+    API_KOBOLD: 6,
     MISTRAL: 7,
     YI: 8,
+    API_TEXTGENERATIONWEBUI: 9,
     BEST_MATCH: 99,
 };
@@ -135,11 +136,11 @@ export function getTokenizerBestMatch(forApi) {
     if (!hasTokenizerError && isConnected) {
         if (forApi === 'kobold' && kai_flags.can_use_tokenization) {
-            return tokenizers.API;
+            return tokenizers.API_KOBOLD;
         }
         if (forApi === 'textgenerationwebui' && isTokenizerSupported) {
-            return tokenizers.API;
+            return tokenizers.API_TEXTGENERATIONWEBUI;
         }
     }
@@ -172,8 +173,10 @@ function callTokenizer(type, str, padding) {
             return countTokensFromServer('/api/tokenizers/mistral/encode', str, padding);
         case tokenizers.YI:
             return countTokensFromServer('/api/tokenizers/yi/encode', str, padding);
-        case tokenizers.API:
-            return countTokensFromRemoteAPI('/api/tokenizers/remote/encode', str, padding);
+        case tokenizers.API_KOBOLD:
+            return countTokensFromKoboldAPI('/api/tokenizers/remote/encode', str, padding);
+        case tokenizers.API_TEXTGENERATIONWEBUI:
+            return countTokensFromTextgenAPI('/api/tokenizers/remote/encode', str, padding);
         default:
             console.warn('Unknown tokenizer type', type);
             return callTokenizer(tokenizers.NONE, str, padding);
@@ -397,13 +400,21 @@ function getServerTokenizationParams(str) {
     };
 }
 
-function getRemoteAPITokenizationParams(str) {
+function getKoboldAPITokenizationParams(str) {
     return {
         text: str,
-        main_api,
+        main_api: 'kobold',
+        url: getAPIServerUrl(),
+    };
+}
+
+function getTextgenAPITokenizationParams(str) {
+    return {
+        text: str,
+        main_api: 'textgenerationwebui',
         api_type: textgen_settings.type,
         url: getAPIServerUrl(),
-        legacy_api: main_api === 'textgenerationwebui' &&
+        legacy_api:
             textgen_settings.legacy_api &&
             textgen_settings.type !== MANCER,
     };
@ -445,14 +456,43 @@ function countTokensFromServer(endpoint, str, padding) {
* @param {number} padding Number of padding tokens. * @param {number} padding Number of padding tokens.
* @returns {number} Token count with padding. * @returns {number} Token count with padding.
*/ */
function countTokensFromRemoteAPI(endpoint, str, padding) { function countTokensFromKoboldAPI(endpoint, str, padding) {
let tokenCount = 0; let tokenCount = 0;
jQuery.ajax({ jQuery.ajax({
async: false, async: false,
type: 'POST', type: 'POST',
url: endpoint, url: endpoint,
data: JSON.stringify(getRemoteAPITokenizationParams(str)), data: JSON.stringify(getKoboldAPITokenizationParams(str)),
dataType: 'json',
contentType: 'application/json',
success: function (data) {
if (typeof data.count === 'number') {
tokenCount = data.count;
} else {
tokenCount = apiFailureTokenCount(str);
}
},
});
return tokenCount + padding;
}
/**
* Count tokens using the AI provider's API.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count with padding.
*/
function countTokensFromTextgenAPI(endpoint, str, padding) {
let tokenCount = 0;
jQuery.ajax({
async: false,
type: 'POST',
url: endpoint,
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
dataType: 'json', dataType: 'json',
contentType: 'application/json', contentType: 'application/json',
success: function (data) { success: function (data) {
@@ -519,16 +559,15 @@ function getTextTokensFromServer(endpoint, str, model = '') {
  * Calls the AI provider's tokenize API to encode a string to tokens.
  * @param {string} endpoint API endpoint.
  * @param {string} str String to tokenize.
- * @param {string} model Tokenizer model.
  * @returns {number[]} Array of token ids.
  */
-function getTextTokensFromRemoteAPI(endpoint, str, model = '') {
+function getTextTokensFromTextgenAPI(endpoint, str) {
     let ids = [];
     jQuery.ajax({
         async: false,
         type: 'POST',
         url: endpoint,
-        data: JSON.stringify(getRemoteAPITokenizationParams(str)),
+        data: JSON.stringify(getTextgenAPITokenizationParams(str)),
         dataType: 'json',
         contentType: 'application/json',
         success: function (data) {
@@ -587,8 +626,8 @@ export function getTextTokens(tokenizerType, str) {
             const model = getTokenizerModel();
             return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model);
         }
-        case tokenizers.API:
-            return getTextTokensFromRemoteAPI('/api/tokenizers/remote/encode', str);
+        case tokenizers.API_TEXTGENERATIONWEBUI:
+            return getTextTokensFromTextgenAPI('/api/tokenizers/remote/encode', str);
         default:
             console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType);
             return [];
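
After this change, either of the new constants works for counting, but
encoding to token ids only works for the text generation WebUI one; the
Kobold constant falls through to the unsupported-tokenizer warning in
getTextTokens and yields an empty array. A minimal sketch of the
distinction (callTokenizer is module-private, so this is illustrative
rather than a real external call site):

    // Both constants are dispatched by callTokenizer for counting:
    callTokenizer(tokenizers.API_KOBOLD, 'Hello world', 0);
    callTokenizer(tokenizers.API_TEXTGENERATIONWEBUI, 'Hello world', 0);

    // Only the textgen constant is handled by getTextTokens for encoding:
    getTextTokens(tokenizers.API_TEXTGENERATIONWEBUI, 'Hello world'); // -> number[]
    getTextTokens(tokenizers.API_KOBOLD, 'Hello world');              // -> [] plus a console warning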