diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js index d21c1abb0..1c8420616 100644 --- a/public/scripts/tokenizers.js +++ b/public/scripts/tokenizers.js @@ -34,6 +34,51 @@ export const SENTENCEPIECE_TOKENIZERS = [ //tokenizers.NERD2, ]; +const TOKENIZER_URLS = { + [tokenizers.GPT2]: { + encode: '/api/tokenizers/gpt2/encode', + decode: '/api/tokenizers/gpt2/decode', + count: '/api/tokenizers/gpt2/encode', + }, + [tokenizers.OPENAI]: { + encode: '/api/tokenizers/openai/encode', + decode: '/api/tokenizers/openai/decode', + count: '/api/tokenizers/openai/encode', + }, + [tokenizers.LLAMA]: { + encode: '/api/tokenizers/llama/encode', + decode: '/api/tokenizers/llama/decode', + count: '/api/tokenizers/llama/encode', + }, + [tokenizers.NERD]: { + encode: '/api/tokenizers/nerdstash/encode', + decode: '/api/tokenizers/nerdstash/decode', + count: '/api/tokenizers/nerdstash/encode', + }, + [tokenizers.NERD2]: { + encode: '/api/tokenizers/nerdstash_v2/encode', + decode: '/api/tokenizers/nerdstash_v2/decode', + count: '/api/tokenizers/nerdstash_v2/encode', + }, + [tokenizers.API_KOBOLD]: { + count: '/api/tokenizers/remote/kobold/count', + }, + [tokenizers.MISTRAL]: { + encode: '/api/tokenizers/mistral/encode', + decode: '/api/tokenizers/mistral/decode', + count: '/api/tokenizers/mistral/encode', + }, + [tokenizers.YI]: { + encode: '/api/tokenizers/yi/encode', + decode: '/api/tokenizers/yi/decode', + count: '/api/tokenizers/yi/encode', + }, + [tokenizers.API_TEXTGENERATIONWEBUI]: { + encode: '/api/tokenizers/remote/textgenerationwebui/encode', + count: '/api/tokenizers/remote/textgenerationwebui/encode', + }, +}; + const objectStore = new localforage.createInstance({ name: 'SillyTavern_ChatCompletions' }); let tokenCache = {}; @@ -158,28 +203,21 @@ export function getTokenizerBestMatch(forApi) { * @returns {number} Token count. 
*/ function callTokenizer(type, str, padding) { + if (type === tokenizers.NONE) return guesstimate(str) + padding; + switch (type) { - case tokenizers.NONE: - return guesstimate(str) + padding; - case tokenizers.GPT2: - return countTokensFromServer('/api/tokenizers/gpt2/encode', str, padding); - case tokenizers.LLAMA: - return countTokensFromServer('/api/tokenizers/llama/encode', str, padding); - case tokenizers.NERD: - return countTokensFromServer('/api/tokenizers/nerdstash/encode', str, padding); - case tokenizers.NERD2: - return countTokensFromServer('/api/tokenizers/nerdstash_v2/encode', str, padding); - case tokenizers.MISTRAL: - return countTokensFromServer('/api/tokenizers/mistral/encode', str, padding); - case tokenizers.YI: - return countTokensFromServer('/api/tokenizers/yi/encode', str, padding); case tokenizers.API_KOBOLD: - return countTokensFromKoboldAPI('/api/tokenizers/remote/kobold/count', str, padding); + return countTokensFromKoboldAPI(str, padding); case tokenizers.API_TEXTGENERATIONWEBUI: - return countTokensFromTextgenAPI('/api/tokenizers/remote/textgenerationwebui/encode', str, padding); - default: - console.warn('Unknown tokenizer type', type); - return callTokenizer(tokenizers.NONE, str, padding); + return countTokensFromTextgenAPI(str, padding); + default: { + const endpointUrl = TOKENIZER_URLS[type]?.count; + if (!endpointUrl) { + console.warn('Unknown tokenizer type', type); + return callTokenizer(tokenizers.NONE, str, padding); + } + return countTokensFromServer(endpointUrl, str, padding); + } } } @@ -425,18 +463,17 @@ function countTokensFromServer(endpoint, str, padding) { /** * Count tokens using the AI provider's API. - * @param {string} endpoint API endpoint. * @param {string} str String to tokenize. * @param {number} padding Number of padding tokens. * @returns {number} Token count with padding. 
*/ -function countTokensFromKoboldAPI(endpoint, str, padding) { +function countTokensFromKoboldAPI(str, padding) { let tokenCount = 0; jQuery.ajax({ async: false, type: 'POST', - url: endpoint, + url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count, data: JSON.stringify({ text: str, url: api_server, @@ -468,18 +505,17 @@ function getTextgenAPITokenizationParams(str) { /** * Count tokens using the AI provider's API. - * @param {string} endpoint API endpoint. * @param {string} str String to tokenize. * @param {number} padding Number of padding tokens. * @returns {number} Token count with padding. */ -function countTokensFromTextgenAPI(endpoint, str, padding) { +function countTokensFromTextgenAPI(str, padding) { let tokenCount = 0; jQuery.ajax({ async: false, type: 'POST', - url: endpoint, + url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count, data: JSON.stringify(getTextgenAPITokenizationParams(str)), dataType: 'json', contentType: 'application/json', @@ -515,14 +551,9 @@ function apiFailureTokenCount(str) { * Calls the underlying tokenizer model to encode a string to tokens. * @param {string} endpoint API endpoint. * @param {string} str String to tokenize. - * @param {string} model Tokenizer model. * @returns {number[]} Array of token ids. */ -function getTextTokensFromServer(endpoint, str, model = '') { - if (model) { - endpoint += `?model=${model}`; - } - +function getTextTokensFromServer(endpoint, str) { let ids = []; jQuery.ajax({ async: false, @@ -545,16 +576,15 @@ function getTextTokensFromServer(endpoint, str, model = '') { /** * Calls the AI provider's tokenize API to encode a string to tokens. - * @param {string} endpoint API endpoint. * @param {string} str String to tokenize. * @returns {number[]} Array of token ids. 
*/ -function getTextTokensFromTextgenAPI(endpoint, str) { +function getTextTokensFromTextgenAPI(str) { let ids = []; jQuery.ajax({ async: false, type: 'POST', - url: endpoint, + url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode, data: JSON.stringify(getTextgenAPITokenizationParams(str)), dataType: 'json', contentType: 'application/json', @@ -570,11 +600,7 @@ function getTextTokensFromTextgenAPI(endpoint, str) { * @param {string} endpoint API endpoint. * @param {number[]} ids Array of token ids */ -function decodeTextTokensFromServer(endpoint, ids, model = '') { - if (model) { - endpoint += `?model=${model}`; - } - +function decodeTextTokensFromServer(endpoint, ids) { let text = ''; jQuery.ajax({ async: false, @@ -598,27 +624,24 @@ function decodeTextTokensFromServer(endpoint, ids, model = '') { */ export function getTextTokens(tokenizerType, str) { switch (tokenizerType) { - case tokenizers.GPT2: - return getTextTokensFromServer('/api/tokenizers/gpt2/encode', str); - case tokenizers.LLAMA: - return getTextTokensFromServer('/api/tokenizers/llama/encode', str); - case tokenizers.NERD: - return getTextTokensFromServer('/api/tokenizers/nerdstash/encode', str); - case tokenizers.NERD2: - return getTextTokensFromServer('/api/tokenizers/nerdstash_v2/encode', str); - case tokenizers.MISTRAL: - return getTextTokensFromServer('/api/tokenizers/mistral/encode', str); - case tokenizers.YI: - return getTextTokensFromServer('/api/tokenizers/yi/encode', str); - case tokenizers.OPENAI: { - const model = getTokenizerModel(); - return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model); - } case tokenizers.API_TEXTGENERATIONWEBUI: - return getTextTokensFromTextgenAPI('/api/tokenizers/textgenerationwebui/encode', str); - default: - console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType); - return []; + return getTextTokensFromTextgenAPI(str); + default: { + const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType]; + if 
(!tokenizerEndpoints) { + console.warn('Unknown tokenizer type', tokenizerType); + return []; + } + let endpointUrl = tokenizerEndpoints.encode; + if (!endpointUrl) { + console.warn('This tokenizer type does not support encoding', tokenizerType); + return []; + } + if (tokenizerType === tokenizers.OPENAI) { + endpointUrl += `?model=${getTokenizerModel()}`; + } + return getTextTokensFromServer(endpointUrl, str); + } } } @@ -628,27 +651,20 @@ export function getTextTokens(tokenizerType, str) { * @param {number[]} ids Array of token ids */ export function decodeTextTokens(tokenizerType, ids) { - switch (tokenizerType) { - case tokenizers.GPT2: - return decodeTextTokensFromServer('/api/tokenizers/gpt2/decode', ids); - case tokenizers.LLAMA: - return decodeTextTokensFromServer('/api/tokenizers/llama/decode', ids); - case tokenizers.NERD: - return decodeTextTokensFromServer('/api/tokenizers/nerdstash/decode', ids); - case tokenizers.NERD2: - return decodeTextTokensFromServer('/api/tokenizers/nerdstash_v2/decode', ids); - case tokenizers.MISTRAL: - return decodeTextTokensFromServer('/api/tokenizers/mistral/decode', ids); - case tokenizers.YI: - return decodeTextTokensFromServer('/api/tokenizers/yi/decode', ids); - case tokenizers.OPENAI: { - const model = getTokenizerModel(); - return decodeTextTokensFromServer('/api/tokenizers/openai/decode', ids, model); - } - default: - console.warn('Calling decodeTextTokens with unsupported tokenizer type', tokenizerType); - return ''; + const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType]; + if (!tokenizerEndpoints) { + console.warn('Unknown tokenizer type', tokenizerType); + return ''; } + let endpointUrl = tokenizerEndpoints.decode; + if (!endpointUrl) { + console.warn('This tokenizer type does not support decoding', tokenizerType); + return ''; + } + if (tokenizerType === tokenizers.OPENAI) { + endpointUrl += `?model=${getTokenizerModel()}`; + } + return decodeTextTokensFromServer(endpointUrl, ids); } export async function 
initTokenizers() {