Clean up tokenizer API code

Store the URLs for each tokenizer's action in one place at the top of
the file, instead of in a bunch of switch-cases. The URLs for the
textgen and Kobold APIs don't change and hence don't need to be
function arguments.
This commit is contained in:
valadaptive 2023-12-09 20:48:41 -05:00
parent 09465fbb97
commit 2f2cd197cc
1 changed files with 95 additions and 79 deletions

View File

@ -34,6 +34,51 @@ export const SENTENCEPIECE_TOKENIZERS = [
//tokenizers.NERD2,
];
const TOKENIZER_URLS = {
[tokenizers.GPT2]: {
encode: '/api/tokenizers/gpt2/encode',
decode: '/api/tokenizers/gpt2/decode',
count: '/api/tokenizers/gpt2/encode',
},
[tokenizers.OPENAI]: {
encode: '/api/tokenizers/openai/encode',
decode: '/api/tokenizers/openai/decode',
count: '/api/tokenizers/openai/encode',
},
[tokenizers.LLAMA]: {
encode: '/api/tokenizers/llama/encode',
decode: '/api/tokenizers/llama/decode',
count: '/api/tokenizers/llama/encode',
},
[tokenizers.NERD]: {
encode: '/api/tokenizers/nerdstash/encode',
decode: '/api/tokenizers/nerdstash/decode',
count: '/api/tokenizers/nerdstash/encode',
},
[tokenizers.NERD2]: {
encode: '/api/tokenizers/nerdstash_v2/encode',
decode: '/api/tokenizers/nerdstash_v2/decode',
count: '/api/tokenizers/nerdstash_v2/encode',
},
[tokenizers.API_KOBOLD]: {
count: '/api/tokenizers/remote/kobold/count',
},
[tokenizers.MISTRAL]: {
encode: '/api/tokenizers/mistral/encode',
decode: '/api/tokenizers/mistral/decode',
count: '/api/tokenizers/mistral/encode',
},
[tokenizers.YI]: {
encode: '/api/tokenizers/yi/encode',
decode: '/api/tokenizers/yi/decode',
count: '/api/tokenizers/yi/encode',
},
[tokenizers.API_TEXTGENERATIONWEBUI]: {
encode: '/api/tokenizers/remote/textgenerationwebui/encode',
count: '/api/tokenizers/remote/textgenerationwebui/encode',
},
};
const objectStore = new localforage.createInstance({ name: 'SillyTavern_ChatCompletions' });
let tokenCache = {};
@ -158,28 +203,21 @@ export function getTokenizerBestMatch(forApi) {
* @returns {number} Token count.
*/
function callTokenizer(type, str, padding) {
if (type === tokenizers.NONE) return guesstimate(str) + padding;
switch (type) {
case tokenizers.NONE:
return guesstimate(str) + padding;
case tokenizers.GPT2:
return countTokensFromServer('/api/tokenizers/gpt2/encode', str, padding);
case tokenizers.LLAMA:
return countTokensFromServer('/api/tokenizers/llama/encode', str, padding);
case tokenizers.NERD:
return countTokensFromServer('/api/tokenizers/nerdstash/encode', str, padding);
case tokenizers.NERD2:
return countTokensFromServer('/api/tokenizers/nerdstash_v2/encode', str, padding);
case tokenizers.MISTRAL:
return countTokensFromServer('/api/tokenizers/mistral/encode', str, padding);
case tokenizers.YI:
return countTokensFromServer('/api/tokenizers/yi/encode', str, padding);
case tokenizers.API_KOBOLD:
return countTokensFromKoboldAPI('/api/tokenizers/remote/kobold/count', str, padding);
return countTokensFromKoboldAPI(str, padding);
case tokenizers.API_TEXTGENERATIONWEBUI:
return countTokensFromTextgenAPI('/api/tokenizers/remote/textgenerationwebui/encode', str, padding);
default:
console.warn('Unknown tokenizer type', type);
return callTokenizer(tokenizers.NONE, str, padding);
return countTokensFromTextgenAPI(str, padding);
default: {
const endpointUrl = TOKENIZER_URLS[type]?.count;
if (!endpointUrl) {
console.warn('Unknown tokenizer type', type);
return callTokenizer(tokenizers.NONE, str, padding);
}
return countTokensFromServer(endpointUrl, str, padding);
}
}
}
@ -425,18 +463,17 @@ function countTokensFromServer(endpoint, str, padding) {
/**
* Count tokens using the AI provider's API.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count with padding.
*/
function countTokensFromKoboldAPI(endpoint, str, padding) {
function countTokensFromKoboldAPI(str, padding) {
let tokenCount = 0;
jQuery.ajax({
async: false,
type: 'POST',
url: endpoint,
url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count,
data: JSON.stringify({
text: str,
url: api_server,
@ -468,18 +505,17 @@ function getTextgenAPITokenizationParams(str) {
/**
* Count tokens using the AI provider's API.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count with padding.
*/
function countTokensFromTextgenAPI(endpoint, str, padding) {
function countTokensFromTextgenAPI(str, padding) {
let tokenCount = 0;
jQuery.ajax({
async: false,
type: 'POST',
url: endpoint,
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count,
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
dataType: 'json',
contentType: 'application/json',
@ -515,14 +551,9 @@ function apiFailureTokenCount(str) {
* Calls the underlying tokenizer model to encode a string to tokens.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {string} model Tokenizer model.
* @returns {number[]} Array of token ids.
*/
function getTextTokensFromServer(endpoint, str, model = '') {
if (model) {
endpoint += `?model=${model}`;
}
function getTextTokensFromServer(endpoint, str) {
let ids = [];
jQuery.ajax({
async: false,
@ -545,16 +576,15 @@ function getTextTokensFromServer(endpoint, str, model = '') {
/**
* Calls the AI provider's tokenize API to encode a string to tokens.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @returns {number[]} Array of token ids.
*/
function getTextTokensFromTextgenAPI(endpoint, str) {
function getTextTokensFromTextgenAPI(str) {
let ids = [];
jQuery.ajax({
async: false,
type: 'POST',
url: endpoint,
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode,
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
dataType: 'json',
contentType: 'application/json',
@ -570,11 +600,7 @@ function getTextTokensFromTextgenAPI(endpoint, str) {
* @param {string} endpoint API endpoint.
* @param {number[]} ids Array of token ids
*/
function decodeTextTokensFromServer(endpoint, ids, model = '') {
if (model) {
endpoint += `?model=${model}`;
}
function decodeTextTokensFromServer(endpoint, ids) {
let text = '';
jQuery.ajax({
async: false,
@ -598,27 +624,24 @@ function decodeTextTokensFromServer(endpoint, ids, model = '') {
*/
export function getTextTokens(tokenizerType, str) {
switch (tokenizerType) {
case tokenizers.GPT2:
return getTextTokensFromServer('/api/tokenizers/gpt2/encode', str);
case tokenizers.LLAMA:
return getTextTokensFromServer('/api/tokenizers/llama/encode', str);
case tokenizers.NERD:
return getTextTokensFromServer('/api/tokenizers/nerdstash/encode', str);
case tokenizers.NERD2:
return getTextTokensFromServer('/api/tokenizers/nerdstash_v2/encode', str);
case tokenizers.MISTRAL:
return getTextTokensFromServer('/api/tokenizers/mistral/encode', str);
case tokenizers.YI:
return getTextTokensFromServer('/api/tokenizers/yi/encode', str);
case tokenizers.OPENAI: {
const model = getTokenizerModel();
return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model);
}
case tokenizers.API_TEXTGENERATIONWEBUI:
return getTextTokensFromTextgenAPI('/api/tokenizers/textgenerationwebui/encode', str);
default:
console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType);
return [];
return getTextTokensFromTextgenAPI(str);
default: {
const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
if (!tokenizerEndpoints) {
console.warn('Unknown tokenizer type', tokenizerType);
return [];
}
let endpointUrl = tokenizerEndpoints.encode;
if (!endpointUrl) {
console.warn('This tokenizer type does not support encoding', tokenizerType);
return [];
}
if (tokenizerType === tokenizers.OPENAI) {
endpointUrl += `?model=${getTokenizerModel()}`;
}
return getTextTokensFromServer(endpointUrl, str);
}
}
}
@ -628,27 +651,20 @@ export function getTextTokens(tokenizerType, str) {
* @param {number[]} ids Array of token ids
*/
export function decodeTextTokens(tokenizerType, ids) {
switch (tokenizerType) {
case tokenizers.GPT2:
return decodeTextTokensFromServer('/api/tokenizers/gpt2/decode', ids);
case tokenizers.LLAMA:
return decodeTextTokensFromServer('/api/tokenizers/llama/decode', ids);
case tokenizers.NERD:
return decodeTextTokensFromServer('/api/tokenizers/nerdstash/decode', ids);
case tokenizers.NERD2:
return decodeTextTokensFromServer('/api/tokenizers/nerdstash_v2/decode', ids);
case tokenizers.MISTRAL:
return decodeTextTokensFromServer('/api/tokenizers/mistral/decode', ids);
case tokenizers.YI:
return decodeTextTokensFromServer('/api/tokenizers/yi/decode', ids);
case tokenizers.OPENAI: {
const model = getTokenizerModel();
return decodeTextTokensFromServer('/api/tokenizers/openai/decode', ids, model);
}
default:
console.warn('Calling decodeTextTokens with unsupported tokenizer type', tokenizerType);
return '';
const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
if (!tokenizerEndpoints) {
console.warn('Unknown tokenizer type', tokenizerType);
return [];
}
let endpointUrl = tokenizerEndpoints.decode;
if (!endpointUrl) {
console.warn('This tokenizer type does not support decoding', tokenizerType);
return [];
}
if (tokenizerType === tokenizers.OPENAI) {
endpointUrl += `?model=${getTokenizerModel()}`;
}
return decodeTextTokensFromServer(endpointUrl, ids);
}
export async function initTokenizers() {