Clean up tokenizer API code
Store the URLs for each tokenizer's action in one place at the top of the file, instead of in a bunch of switch-cases. The URLs for the textgen and Kobold APIs don't change and hence don't need to be function arguments.
This commit is contained in:
parent
09465fbb97
commit
2f2cd197cc
|
@ -34,6 +34,51 @@ export const SENTENCEPIECE_TOKENIZERS = [
|
|||
//tokenizers.NERD2,
|
||||
];
|
||||
|
||||
const TOKENIZER_URLS = {
|
||||
[tokenizers.GPT2]: {
|
||||
encode: '/api/tokenizers/gpt2/encode',
|
||||
decode: '/api/tokenizers/gpt2/decode',
|
||||
count: '/api/tokenizers/gpt2/encode',
|
||||
},
|
||||
[tokenizers.OPENAI]: {
|
||||
encode: '/api/tokenizers/openai/encode',
|
||||
decode: '/api/tokenizers/openai/decode',
|
||||
count: '/api/tokenizers/openai/encode',
|
||||
},
|
||||
[tokenizers.LLAMA]: {
|
||||
encode: '/api/tokenizers/llama/encode',
|
||||
decode: '/api/tokenizers/llama/decode',
|
||||
count: '/api/tokenizers/llama/encode',
|
||||
},
|
||||
[tokenizers.NERD]: {
|
||||
encode: '/api/tokenizers/nerdstash/encode',
|
||||
decode: '/api/tokenizers/nerdstash/decode',
|
||||
count: '/api/tokenizers/nerdstash/encode',
|
||||
},
|
||||
[tokenizers.NERD2]: {
|
||||
encode: '/api/tokenizers/nerdstash_v2/encode',
|
||||
decode: '/api/tokenizers/nerdstash_v2/decode',
|
||||
count: '/api/tokenizers/nerdstash_v2/encode',
|
||||
},
|
||||
[tokenizers.API_KOBOLD]: {
|
||||
count: '/api/tokenizers/remote/kobold/count',
|
||||
},
|
||||
[tokenizers.MISTRAL]: {
|
||||
encode: '/api/tokenizers/mistral/encode',
|
||||
decode: '/api/tokenizers/mistral/decode',
|
||||
count: '/api/tokenizers/mistral/encode',
|
||||
},
|
||||
[tokenizers.YI]: {
|
||||
encode: '/api/tokenizers/yi/encode',
|
||||
decode: '/api/tokenizers/yi/decode',
|
||||
count: '/api/tokenizers/yi/encode',
|
||||
},
|
||||
[tokenizers.API_TEXTGENERATIONWEBUI]: {
|
||||
encode: '/api/tokenizers/remote/textgenerationwebui/encode',
|
||||
count: '/api/tokenizers/remote/textgenerationwebui/encode',
|
||||
},
|
||||
};
|
||||
|
||||
const objectStore = new localforage.createInstance({ name: 'SillyTavern_ChatCompletions' });
|
||||
|
||||
let tokenCache = {};
|
||||
|
@ -158,28 +203,21 @@ export function getTokenizerBestMatch(forApi) {
|
|||
* @returns {number} Token count.
|
||||
*/
|
||||
function callTokenizer(type, str, padding) {
|
||||
if (type === tokenizers.NONE) return guesstimate(str) + padding;
|
||||
|
||||
switch (type) {
|
||||
case tokenizers.NONE:
|
||||
return guesstimate(str) + padding;
|
||||
case tokenizers.GPT2:
|
||||
return countTokensFromServer('/api/tokenizers/gpt2/encode', str, padding);
|
||||
case tokenizers.LLAMA:
|
||||
return countTokensFromServer('/api/tokenizers/llama/encode', str, padding);
|
||||
case tokenizers.NERD:
|
||||
return countTokensFromServer('/api/tokenizers/nerdstash/encode', str, padding);
|
||||
case tokenizers.NERD2:
|
||||
return countTokensFromServer('/api/tokenizers/nerdstash_v2/encode', str, padding);
|
||||
case tokenizers.MISTRAL:
|
||||
return countTokensFromServer('/api/tokenizers/mistral/encode', str, padding);
|
||||
case tokenizers.YI:
|
||||
return countTokensFromServer('/api/tokenizers/yi/encode', str, padding);
|
||||
case tokenizers.API_KOBOLD:
|
||||
return countTokensFromKoboldAPI('/api/tokenizers/remote/kobold/count', str, padding);
|
||||
return countTokensFromKoboldAPI(str, padding);
|
||||
case tokenizers.API_TEXTGENERATIONWEBUI:
|
||||
return countTokensFromTextgenAPI('/api/tokenizers/remote/textgenerationwebui/encode', str, padding);
|
||||
default:
|
||||
console.warn('Unknown tokenizer type', type);
|
||||
return callTokenizer(tokenizers.NONE, str, padding);
|
||||
return countTokensFromTextgenAPI(str, padding);
|
||||
default: {
|
||||
const endpointUrl = TOKENIZER_URLS[type]?.count;
|
||||
if (!endpointUrl) {
|
||||
console.warn('Unknown tokenizer type', type);
|
||||
return callTokenizer(tokenizers.NONE, str, padding);
|
||||
}
|
||||
return countTokensFromServer(endpointUrl, str, padding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -425,18 +463,17 @@ function countTokensFromServer(endpoint, str, padding) {
|
|||
|
||||
/**
|
||||
* Count tokens using the AI provider's API.
|
||||
* @param {string} endpoint API endpoint.
|
||||
* @param {string} str String to tokenize.
|
||||
* @param {number} padding Number of padding tokens.
|
||||
* @returns {number} Token count with padding.
|
||||
*/
|
||||
function countTokensFromKoboldAPI(endpoint, str, padding) {
|
||||
function countTokensFromKoboldAPI(str, padding) {
|
||||
let tokenCount = 0;
|
||||
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
type: 'POST',
|
||||
url: endpoint,
|
||||
url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count,
|
||||
data: JSON.stringify({
|
||||
text: str,
|
||||
url: api_server,
|
||||
|
@ -468,18 +505,17 @@ function getTextgenAPITokenizationParams(str) {
|
|||
|
||||
/**
|
||||
* Count tokens using the AI provider's API.
|
||||
* @param {string} endpoint API endpoint.
|
||||
* @param {string} str String to tokenize.
|
||||
* @param {number} padding Number of padding tokens.
|
||||
* @returns {number} Token count with padding.
|
||||
*/
|
||||
function countTokensFromTextgenAPI(endpoint, str, padding) {
|
||||
function countTokensFromTextgenAPI(str, padding) {
|
||||
let tokenCount = 0;
|
||||
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
type: 'POST',
|
||||
url: endpoint,
|
||||
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count,
|
||||
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
|
||||
dataType: 'json',
|
||||
contentType: 'application/json',
|
||||
|
@ -515,14 +551,9 @@ function apiFailureTokenCount(str) {
|
|||
* Calls the underlying tokenizer model to encode a string to tokens.
|
||||
* @param {string} endpoint API endpoint.
|
||||
* @param {string} str String to tokenize.
|
||||
* @param {string} model Tokenizer model.
|
||||
* @returns {number[]} Array of token ids.
|
||||
*/
|
||||
function getTextTokensFromServer(endpoint, str, model = '') {
|
||||
if (model) {
|
||||
endpoint += `?model=${model}`;
|
||||
}
|
||||
|
||||
function getTextTokensFromServer(endpoint, str) {
|
||||
let ids = [];
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
|
@ -545,16 +576,15 @@ function getTextTokensFromServer(endpoint, str, model = '') {
|
|||
|
||||
/**
|
||||
* Calls the AI provider's tokenize API to encode a string to tokens.
|
||||
* @param {string} endpoint API endpoint.
|
||||
* @param {string} str String to tokenize.
|
||||
* @returns {number[]} Array of token ids.
|
||||
*/
|
||||
function getTextTokensFromTextgenAPI(endpoint, str) {
|
||||
function getTextTokensFromTextgenAPI(str) {
|
||||
let ids = [];
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
type: 'POST',
|
||||
url: endpoint,
|
||||
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode,
|
||||
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
|
||||
dataType: 'json',
|
||||
contentType: 'application/json',
|
||||
|
@ -570,11 +600,7 @@ function getTextTokensFromTextgenAPI(endpoint, str) {
|
|||
* @param {string} endpoint API endpoint.
|
||||
* @param {number[]} ids Array of token ids
|
||||
*/
|
||||
function decodeTextTokensFromServer(endpoint, ids, model = '') {
|
||||
if (model) {
|
||||
endpoint += `?model=${model}`;
|
||||
}
|
||||
|
||||
function decodeTextTokensFromServer(endpoint, ids) {
|
||||
let text = '';
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
|
@ -598,27 +624,24 @@ function decodeTextTokensFromServer(endpoint, ids, model = '') {
|
|||
*/
|
||||
export function getTextTokens(tokenizerType, str) {
|
||||
switch (tokenizerType) {
|
||||
case tokenizers.GPT2:
|
||||
return getTextTokensFromServer('/api/tokenizers/gpt2/encode', str);
|
||||
case tokenizers.LLAMA:
|
||||
return getTextTokensFromServer('/api/tokenizers/llama/encode', str);
|
||||
case tokenizers.NERD:
|
||||
return getTextTokensFromServer('/api/tokenizers/nerdstash/encode', str);
|
||||
case tokenizers.NERD2:
|
||||
return getTextTokensFromServer('/api/tokenizers/nerdstash_v2/encode', str);
|
||||
case tokenizers.MISTRAL:
|
||||
return getTextTokensFromServer('/api/tokenizers/mistral/encode', str);
|
||||
case tokenizers.YI:
|
||||
return getTextTokensFromServer('/api/tokenizers/yi/encode', str);
|
||||
case tokenizers.OPENAI: {
|
||||
const model = getTokenizerModel();
|
||||
return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model);
|
||||
}
|
||||
case tokenizers.API_TEXTGENERATIONWEBUI:
|
||||
return getTextTokensFromTextgenAPI('/api/tokenizers/textgenerationwebui/encode', str);
|
||||
default:
|
||||
console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType);
|
||||
return [];
|
||||
return getTextTokensFromTextgenAPI(str);
|
||||
default: {
|
||||
const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
|
||||
if (!tokenizerEndpoints) {
|
||||
console.warn('Unknown tokenizer type', tokenizerType);
|
||||
return [];
|
||||
}
|
||||
let endpointUrl = tokenizerEndpoints.encode;
|
||||
if (!endpointUrl) {
|
||||
console.warn('This tokenizer type does not support encoding', tokenizerType);
|
||||
return [];
|
||||
}
|
||||
if (tokenizerType === tokenizers.OPENAI) {
|
||||
endpointUrl += `?model=${getTokenizerModel()}`;
|
||||
}
|
||||
return getTextTokensFromServer(endpointUrl, str);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -628,27 +651,20 @@ export function getTextTokens(tokenizerType, str) {
|
|||
* @param {number[]} ids Array of token ids
|
||||
*/
|
||||
export function decodeTextTokens(tokenizerType, ids) {
|
||||
switch (tokenizerType) {
|
||||
case tokenizers.GPT2:
|
||||
return decodeTextTokensFromServer('/api/tokenizers/gpt2/decode', ids);
|
||||
case tokenizers.LLAMA:
|
||||
return decodeTextTokensFromServer('/api/tokenizers/llama/decode', ids);
|
||||
case tokenizers.NERD:
|
||||
return decodeTextTokensFromServer('/api/tokenizers/nerdstash/decode', ids);
|
||||
case tokenizers.NERD2:
|
||||
return decodeTextTokensFromServer('/api/tokenizers/nerdstash_v2/decode', ids);
|
||||
case tokenizers.MISTRAL:
|
||||
return decodeTextTokensFromServer('/api/tokenizers/mistral/decode', ids);
|
||||
case tokenizers.YI:
|
||||
return decodeTextTokensFromServer('/api/tokenizers/yi/decode', ids);
|
||||
case tokenizers.OPENAI: {
|
||||
const model = getTokenizerModel();
|
||||
return decodeTextTokensFromServer('/api/tokenizers/openai/decode', ids, model);
|
||||
}
|
||||
default:
|
||||
console.warn('Calling decodeTextTokens with unsupported tokenizer type', tokenizerType);
|
||||
return '';
|
||||
const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
|
||||
if (!tokenizerEndpoints) {
|
||||
console.warn('Unknown tokenizer type', tokenizerType);
|
||||
return [];
|
||||
}
|
||||
let endpointUrl = tokenizerEndpoints.decode;
|
||||
if (!endpointUrl) {
|
||||
console.warn('This tokenizer type does not support decoding', tokenizerType);
|
||||
return [];
|
||||
}
|
||||
if (tokenizerType === tokenizers.OPENAI) {
|
||||
endpointUrl += `?model=${getTokenizerModel()}`;
|
||||
}
|
||||
return decodeTextTokensFromServer(endpointUrl, ids);
|
||||
}
|
||||
|
||||
export async function initTokenizers() {
|
||||
|
|
Loading…
Reference in New Issue