Clean up tokenizer API code

Store the URLs for each tokenizer's action in one place at the top of
the file, instead of in a bunch of switch-cases. The URLs for the
textgen and Kobold APIs don't change and hence don't need to be
function arguments.
This commit is contained in:
valadaptive 2023-12-09 20:48:41 -05:00
parent 09465fbb97
commit 2f2cd197cc
1 changed files with 95 additions and 79 deletions

View File

@ -34,6 +34,51 @@ export const SENTENCEPIECE_TOKENIZERS = [
//tokenizers.NERD2, //tokenizers.NERD2,
]; ];
const TOKENIZER_URLS = {
[tokenizers.GPT2]: {
encode: '/api/tokenizers/gpt2/encode',
decode: '/api/tokenizers/gpt2/decode',
count: '/api/tokenizers/gpt2/encode',
},
[tokenizers.OPENAI]: {
encode: '/api/tokenizers/openai/encode',
decode: '/api/tokenizers/openai/decode',
count: '/api/tokenizers/openai/encode',
},
[tokenizers.LLAMA]: {
encode: '/api/tokenizers/llama/encode',
decode: '/api/tokenizers/llama/decode',
count: '/api/tokenizers/llama/encode',
},
[tokenizers.NERD]: {
encode: '/api/tokenizers/nerdstash/encode',
decode: '/api/tokenizers/nerdstash/decode',
count: '/api/tokenizers/nerdstash/encode',
},
[tokenizers.NERD2]: {
encode: '/api/tokenizers/nerdstash_v2/encode',
decode: '/api/tokenizers/nerdstash_v2/decode',
count: '/api/tokenizers/nerdstash_v2/encode',
},
[tokenizers.API_KOBOLD]: {
count: '/api/tokenizers/remote/kobold/count',
},
[tokenizers.MISTRAL]: {
encode: '/api/tokenizers/mistral/encode',
decode: '/api/tokenizers/mistral/decode',
count: '/api/tokenizers/mistral/encode',
},
[tokenizers.YI]: {
encode: '/api/tokenizers/yi/encode',
decode: '/api/tokenizers/yi/decode',
count: '/api/tokenizers/yi/encode',
},
[tokenizers.API_TEXTGENERATIONWEBUI]: {
encode: '/api/tokenizers/remote/textgenerationwebui/encode',
count: '/api/tokenizers/remote/textgenerationwebui/encode',
},
};
const objectStore = new localforage.createInstance({ name: 'SillyTavern_ChatCompletions' }); const objectStore = new localforage.createInstance({ name: 'SillyTavern_ChatCompletions' });
let tokenCache = {}; let tokenCache = {};
@ -158,28 +203,21 @@ export function getTokenizerBestMatch(forApi) {
* @returns {number} Token count. * @returns {number} Token count.
*/ */
function callTokenizer(type, str, padding) { function callTokenizer(type, str, padding) {
if (type === tokenizers.NONE) return guesstimate(str) + padding;
switch (type) { switch (type) {
case tokenizers.NONE:
return guesstimate(str) + padding;
case tokenizers.GPT2:
return countTokensFromServer('/api/tokenizers/gpt2/encode', str, padding);
case tokenizers.LLAMA:
return countTokensFromServer('/api/tokenizers/llama/encode', str, padding);
case tokenizers.NERD:
return countTokensFromServer('/api/tokenizers/nerdstash/encode', str, padding);
case tokenizers.NERD2:
return countTokensFromServer('/api/tokenizers/nerdstash_v2/encode', str, padding);
case tokenizers.MISTRAL:
return countTokensFromServer('/api/tokenizers/mistral/encode', str, padding);
case tokenizers.YI:
return countTokensFromServer('/api/tokenizers/yi/encode', str, padding);
case tokenizers.API_KOBOLD: case tokenizers.API_KOBOLD:
return countTokensFromKoboldAPI('/api/tokenizers/remote/kobold/count', str, padding); return countTokensFromKoboldAPI(str, padding);
case tokenizers.API_TEXTGENERATIONWEBUI: case tokenizers.API_TEXTGENERATIONWEBUI:
return countTokensFromTextgenAPI('/api/tokenizers/remote/textgenerationwebui/encode', str, padding); return countTokensFromTextgenAPI(str, padding);
default: default: {
console.warn('Unknown tokenizer type', type); const endpointUrl = TOKENIZER_URLS[type]?.count;
return callTokenizer(tokenizers.NONE, str, padding); if (!endpointUrl) {
console.warn('Unknown tokenizer type', type);
return callTokenizer(tokenizers.NONE, str, padding);
}
return countTokensFromServer(endpointUrl, str, padding);
}
} }
} }
@ -425,18 +463,17 @@ function countTokensFromServer(endpoint, str, padding) {
/** /**
* Count tokens using the AI provider's API. * Count tokens using the AI provider's API.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize. * @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens. * @param {number} padding Number of padding tokens.
* @returns {number} Token count with padding. * @returns {number} Token count with padding.
*/ */
function countTokensFromKoboldAPI(endpoint, str, padding) { function countTokensFromKoboldAPI(str, padding) {
let tokenCount = 0; let tokenCount = 0;
jQuery.ajax({ jQuery.ajax({
async: false, async: false,
type: 'POST', type: 'POST',
url: endpoint, url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count,
data: JSON.stringify({ data: JSON.stringify({
text: str, text: str,
url: api_server, url: api_server,
@ -468,18 +505,17 @@ function getTextgenAPITokenizationParams(str) {
/** /**
* Count tokens using the AI provider's API. * Count tokens using the AI provider's API.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize. * @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens. * @param {number} padding Number of padding tokens.
* @returns {number} Token count with padding. * @returns {number} Token count with padding.
*/ */
function countTokensFromTextgenAPI(endpoint, str, padding) { function countTokensFromTextgenAPI(str, padding) {
let tokenCount = 0; let tokenCount = 0;
jQuery.ajax({ jQuery.ajax({
async: false, async: false,
type: 'POST', type: 'POST',
url: endpoint, url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count,
data: JSON.stringify(getTextgenAPITokenizationParams(str)), data: JSON.stringify(getTextgenAPITokenizationParams(str)),
dataType: 'json', dataType: 'json',
contentType: 'application/json', contentType: 'application/json',
@ -515,14 +551,9 @@ function apiFailureTokenCount(str) {
* Calls the underlying tokenizer model to encode a string to tokens. * Calls the underlying tokenizer model to encode a string to tokens.
* @param {string} endpoint API endpoint. * @param {string} endpoint API endpoint.
* @param {string} str String to tokenize. * @param {string} str String to tokenize.
* @param {string} model Tokenizer model.
* @returns {number[]} Array of token ids. * @returns {number[]} Array of token ids.
*/ */
function getTextTokensFromServer(endpoint, str, model = '') { function getTextTokensFromServer(endpoint, str) {
if (model) {
endpoint += `?model=${model}`;
}
let ids = []; let ids = [];
jQuery.ajax({ jQuery.ajax({
async: false, async: false,
@ -545,16 +576,15 @@ function getTextTokensFromServer(endpoint, str, model = '') {
/** /**
* Calls the AI provider's tokenize API to encode a string to tokens. * Calls the AI provider's tokenize API to encode a string to tokens.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize. * @param {string} str String to tokenize.
* @returns {number[]} Array of token ids. * @returns {number[]} Array of token ids.
*/ */
function getTextTokensFromTextgenAPI(endpoint, str) { function getTextTokensFromTextgenAPI(str) {
let ids = []; let ids = [];
jQuery.ajax({ jQuery.ajax({
async: false, async: false,
type: 'POST', type: 'POST',
url: endpoint, url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode,
data: JSON.stringify(getTextgenAPITokenizationParams(str)), data: JSON.stringify(getTextgenAPITokenizationParams(str)),
dataType: 'json', dataType: 'json',
contentType: 'application/json', contentType: 'application/json',
@ -570,11 +600,7 @@ function getTextTokensFromTextgenAPI(endpoint, str) {
* @param {string} endpoint API endpoint. * @param {string} endpoint API endpoint.
* @param {number[]} ids Array of token ids * @param {number[]} ids Array of token ids
*/ */
function decodeTextTokensFromServer(endpoint, ids, model = '') { function decodeTextTokensFromServer(endpoint, ids) {
if (model) {
endpoint += `?model=${model}`;
}
let text = ''; let text = '';
jQuery.ajax({ jQuery.ajax({
async: false, async: false,
@ -598,27 +624,24 @@ function decodeTextTokensFromServer(endpoint, ids, model = '') {
*/ */
export function getTextTokens(tokenizerType, str) { export function getTextTokens(tokenizerType, str) {
switch (tokenizerType) { switch (tokenizerType) {
case tokenizers.GPT2:
return getTextTokensFromServer('/api/tokenizers/gpt2/encode', str);
case tokenizers.LLAMA:
return getTextTokensFromServer('/api/tokenizers/llama/encode', str);
case tokenizers.NERD:
return getTextTokensFromServer('/api/tokenizers/nerdstash/encode', str);
case tokenizers.NERD2:
return getTextTokensFromServer('/api/tokenizers/nerdstash_v2/encode', str);
case tokenizers.MISTRAL:
return getTextTokensFromServer('/api/tokenizers/mistral/encode', str);
case tokenizers.YI:
return getTextTokensFromServer('/api/tokenizers/yi/encode', str);
case tokenizers.OPENAI: {
const model = getTokenizerModel();
return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model);
}
case tokenizers.API_TEXTGENERATIONWEBUI: case tokenizers.API_TEXTGENERATIONWEBUI:
return getTextTokensFromTextgenAPI('/api/tokenizers/textgenerationwebui/encode', str); return getTextTokensFromTextgenAPI(str);
default: default: {
console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType); const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
return []; if (!tokenizerEndpoints) {
console.warn('Unknown tokenizer type', tokenizerType);
return [];
}
let endpointUrl = tokenizerEndpoints.encode;
if (!endpointUrl) {
console.warn('This tokenizer type does not support encoding', tokenizerType);
return [];
}
if (tokenizerType === tokenizers.OPENAI) {
endpointUrl += `?model=${getTokenizerModel()}`;
}
return getTextTokensFromServer(endpointUrl, str);
}
} }
} }
@ -628,27 +651,20 @@ export function getTextTokens(tokenizerType, str) {
* @param {number[]} ids Array of token ids * @param {number[]} ids Array of token ids
*/ */
export function decodeTextTokens(tokenizerType, ids) { export function decodeTextTokens(tokenizerType, ids) {
switch (tokenizerType) { const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
case tokenizers.GPT2: if (!tokenizerEndpoints) {
return decodeTextTokensFromServer('/api/tokenizers/gpt2/decode', ids); console.warn('Unknown tokenizer type', tokenizerType);
case tokenizers.LLAMA: return [];
return decodeTextTokensFromServer('/api/tokenizers/llama/decode', ids);
case tokenizers.NERD:
return decodeTextTokensFromServer('/api/tokenizers/nerdstash/decode', ids);
case tokenizers.NERD2:
return decodeTextTokensFromServer('/api/tokenizers/nerdstash_v2/decode', ids);
case tokenizers.MISTRAL:
return decodeTextTokensFromServer('/api/tokenizers/mistral/decode', ids);
case tokenizers.YI:
return decodeTextTokensFromServer('/api/tokenizers/yi/decode', ids);
case tokenizers.OPENAI: {
const model = getTokenizerModel();
return decodeTextTokensFromServer('/api/tokenizers/openai/decode', ids, model);
}
default:
console.warn('Calling decodeTextTokens with unsupported tokenizer type', tokenizerType);
return '';
} }
let endpointUrl = tokenizerEndpoints.decode;
if (!endpointUrl) {
console.warn('This tokenizer type does not support decoding', tokenizerType);
return [];
}
if (tokenizerType === tokenizers.OPENAI) {
endpointUrl += `?model=${getTokenizerModel()}`;
}
return decodeTextTokensFromServer(endpointUrl, ids);
} }
export async function initTokenizers() { export async function initTokenizers() {