Clean up tokenizer API code
Store the URLs for each tokenizer's action in one place at the top of the file instead of scattering them across switch cases. The URLs for the textgen and Kobold APIs never change, so they no longer need to be passed as function arguments.
parent 09465fbb97
commit 2f2cd197cc
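For orientation, a minimal self-contained sketch of the pattern this commit introduces. The URL shapes are taken from the diff below, but the map is abbreviated to two entries and the numeric tokenizer ids are stand-ins for illustration, not the project's real constants:

// Sketch only: two example entries from the centralized URL map.
const tokenizers = { GPT2: 1, API_KOBOLD: 2 }; // stand-in ids, assumption

const TOKENIZER_URLS = {
    [tokenizers.GPT2]: {
        encode: '/api/tokenizers/gpt2/encode',
        decode: '/api/tokenizers/gpt2/decode',
        count: '/api/tokenizers/gpt2/encode', // counting reuses the encode endpoint
    },
    [tokenizers.API_KOBOLD]: {
        count: '/api/tokenizers/remote/kobold/count', // remote API only supports counting
    },
};

// Unknown types resolve to undefined via optional chaining, letting callers
// fall back to a local estimate instead of calling a nonexistent endpoint.
function resolveCountUrl(type) {
    return TOKENIZER_URLS[type]?.count;
}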
@@ -34,6 +34,51 @@ export const SENTENCEPIECE_TOKENIZERS = [
     //tokenizers.NERD2,
 ];
 
+const TOKENIZER_URLS = {
+    [tokenizers.GPT2]: {
+        encode: '/api/tokenizers/gpt2/encode',
+        decode: '/api/tokenizers/gpt2/decode',
+        count: '/api/tokenizers/gpt2/encode',
+    },
+    [tokenizers.OPENAI]: {
+        encode: '/api/tokenizers/openai/encode',
+        decode: '/api/tokenizers/openai/decode',
+        count: '/api/tokenizers/openai/encode',
+    },
+    [tokenizers.LLAMA]: {
+        encode: '/api/tokenizers/llama/encode',
+        decode: '/api/tokenizers/llama/decode',
+        count: '/api/tokenizers/llama/encode',
+    },
+    [tokenizers.NERD]: {
+        encode: '/api/tokenizers/nerdstash/encode',
+        decode: '/api/tokenizers/nerdstash/decode',
+        count: '/api/tokenizers/nerdstash/encode',
+    },
+    [tokenizers.NERD2]: {
+        encode: '/api/tokenizers/nerdstash_v2/encode',
+        decode: '/api/tokenizers/nerdstash_v2/decode',
+        count: '/api/tokenizers/nerdstash_v2/encode',
+    },
+    [tokenizers.API_KOBOLD]: {
+        count: '/api/tokenizers/remote/kobold/count',
+    },
+    [tokenizers.MISTRAL]: {
+        encode: '/api/tokenizers/mistral/encode',
+        decode: '/api/tokenizers/mistral/decode',
+        count: '/api/tokenizers/mistral/encode',
+    },
+    [tokenizers.YI]: {
+        encode: '/api/tokenizers/yi/encode',
+        decode: '/api/tokenizers/yi/decode',
+        count: '/api/tokenizers/yi/encode',
+    },
+    [tokenizers.API_TEXTGENERATIONWEBUI]: {
+        encode: '/api/tokenizers/remote/textgenerationwebui/encode',
+        count: '/api/tokenizers/remote/textgenerationwebui/encode',
+    },
+};
+
 const objectStore = new localforage.createInstance({ name: 'SillyTavern_ChatCompletions' });
 
 let tokenCache = {};
@@ -158,28 +203,21 @@ export function getTokenizerBestMatch(forApi) {
  * @returns {number} Token count.
  */
 function callTokenizer(type, str, padding) {
+    if (type === tokenizers.NONE) return guesstimate(str) + padding;
+
     switch (type) {
-        case tokenizers.NONE:
-            return guesstimate(str) + padding;
-        case tokenizers.GPT2:
-            return countTokensFromServer('/api/tokenizers/gpt2/encode', str, padding);
-        case tokenizers.LLAMA:
-            return countTokensFromServer('/api/tokenizers/llama/encode', str, padding);
-        case tokenizers.NERD:
-            return countTokensFromServer('/api/tokenizers/nerdstash/encode', str, padding);
-        case tokenizers.NERD2:
-            return countTokensFromServer('/api/tokenizers/nerdstash_v2/encode', str, padding);
-        case tokenizers.MISTRAL:
-            return countTokensFromServer('/api/tokenizers/mistral/encode', str, padding);
-        case tokenizers.YI:
-            return countTokensFromServer('/api/tokenizers/yi/encode', str, padding);
         case tokenizers.API_KOBOLD:
-            return countTokensFromKoboldAPI('/api/tokenizers/remote/kobold/count', str, padding);
+            return countTokensFromKoboldAPI(str, padding);
         case tokenizers.API_TEXTGENERATIONWEBUI:
-            return countTokensFromTextgenAPI('/api/tokenizers/remote/textgenerationwebui/encode', str, padding);
+            return countTokensFromTextgenAPI(str, padding);
-        default:
-            console.warn('Unknown tokenizer type', type);
-            return callTokenizer(tokenizers.NONE, str, padding);
+        default: {
+            const endpointUrl = TOKENIZER_URLS[type]?.count;
+            if (!endpointUrl) {
+                console.warn('Unknown tokenizer type', type);
+                return callTokenizer(tokenizers.NONE, str, padding);
+            }
+            return countTokensFromServer(endpointUrl, str, padding);
+        }
     }
 }
 
@@ -425,18 +463,17 @@ function countTokensFromServer(endpoint, str, padding) {
 
 /**
  * Count tokens using the AI provider's API.
- * @param {string} endpoint API endpoint.
  * @param {string} str String to tokenize.
  * @param {number} padding Number of padding tokens.
  * @returns {number} Token count with padding.
  */
-function countTokensFromKoboldAPI(endpoint, str, padding) {
+function countTokensFromKoboldAPI(str, padding) {
     let tokenCount = 0;
 
     jQuery.ajax({
         async: false,
         type: 'POST',
-        url: endpoint,
+        url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count,
         data: JSON.stringify({
             text: str,
             url: api_server,
@@ -468,18 +505,17 @@ function getTextgenAPITokenizationParams(str) {
 
 /**
  * Count tokens using the AI provider's API.
- * @param {string} endpoint API endpoint.
  * @param {string} str String to tokenize.
  * @param {number} padding Number of padding tokens.
  * @returns {number} Token count with padding.
  */
-function countTokensFromTextgenAPI(endpoint, str, padding) {
+function countTokensFromTextgenAPI(str, padding) {
     let tokenCount = 0;
 
     jQuery.ajax({
         async: false,
         type: 'POST',
-        url: endpoint,
+        url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count,
         data: JSON.stringify(getTextgenAPITokenizationParams(str)),
         dataType: 'json',
         contentType: 'application/json',
@@ -515,14 +551,9 @@ function apiFailureTokenCount(str) {
  * Calls the underlying tokenizer model to encode a string to tokens.
  * @param {string} endpoint API endpoint.
  * @param {string} str String to tokenize.
- * @param {string} model Tokenizer model.
  * @returns {number[]} Array of token ids.
  */
-function getTextTokensFromServer(endpoint, str, model = '') {
-    if (model) {
-        endpoint += `?model=${model}`;
-    }
-
+function getTextTokensFromServer(endpoint, str) {
     let ids = [];
     jQuery.ajax({
         async: false,
@@ -545,16 +576,15 @@ function getTextTokensFromServer(endpoint, str, model = '') {
 
 /**
  * Calls the AI provider's tokenize API to encode a string to tokens.
- * @param {string} endpoint API endpoint.
  * @param {string} str String to tokenize.
  * @returns {number[]} Array of token ids.
  */
-function getTextTokensFromTextgenAPI(endpoint, str) {
+function getTextTokensFromTextgenAPI(str) {
     let ids = [];
     jQuery.ajax({
         async: false,
         type: 'POST',
-        url: endpoint,
+        url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode,
         data: JSON.stringify(getTextgenAPITokenizationParams(str)),
         dataType: 'json',
         contentType: 'application/json',
@@ -570,11 +600,7 @@ function getTextTokensFromTextgenAPI(endpoint, str) {
  * @param {string} endpoint API endpoint.
  * @param {number[]} ids Array of token ids
  */
-function decodeTextTokensFromServer(endpoint, ids, model = '') {
-    if (model) {
-        endpoint += `?model=${model}`;
-    }
-
+function decodeTextTokensFromServer(endpoint, ids) {
     let text = '';
     jQuery.ajax({
         async: false,
@@ -598,27 +624,24 @@ function decodeTextTokensFromServer(endpoint, ids, model = '') {
  */
 export function getTextTokens(tokenizerType, str) {
     switch (tokenizerType) {
-        case tokenizers.GPT2:
-            return getTextTokensFromServer('/api/tokenizers/gpt2/encode', str);
-        case tokenizers.LLAMA:
-            return getTextTokensFromServer('/api/tokenizers/llama/encode', str);
-        case tokenizers.NERD:
-            return getTextTokensFromServer('/api/tokenizers/nerdstash/encode', str);
-        case tokenizers.NERD2:
-            return getTextTokensFromServer('/api/tokenizers/nerdstash_v2/encode', str);
-        case tokenizers.MISTRAL:
-            return getTextTokensFromServer('/api/tokenizers/mistral/encode', str);
-        case tokenizers.YI:
-            return getTextTokensFromServer('/api/tokenizers/yi/encode', str);
-        case tokenizers.OPENAI: {
-            const model = getTokenizerModel();
-            return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model);
-        }
         case tokenizers.API_TEXTGENERATIONWEBUI:
-            return getTextTokensFromTextgenAPI('/api/tokenizers/textgenerationwebui/encode', str);
+            return getTextTokensFromTextgenAPI(str);
-        default:
-            console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType);
-            return [];
+        default: {
+            const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
+            if (!tokenizerEndpoints) {
+                console.warn('Unknown tokenizer type', tokenizerType);
+                return [];
+            }
+            let endpointUrl = tokenizerEndpoints.encode;
+            if (!endpointUrl) {
+                console.warn('This tokenizer type does not support encoding', tokenizerType);
+                return [];
+            }
+            if (tokenizerType === tokenizers.OPENAI) {
+                endpointUrl += `?model=${getTokenizerModel()}`;
+            }
+            return getTextTokensFromServer(endpointUrl, str);
+        }
     }
 }
 
@@ -628,27 +651,20 @@ export function getTextTokens(tokenizerType, str) {
  * @param {number[]} ids Array of token ids
  */
 export function decodeTextTokens(tokenizerType, ids) {
-    switch (tokenizerType) {
-        case tokenizers.GPT2:
-            return decodeTextTokensFromServer('/api/tokenizers/gpt2/decode', ids);
-        case tokenizers.LLAMA:
-            return decodeTextTokensFromServer('/api/tokenizers/llama/decode', ids);
-        case tokenizers.NERD:
-            return decodeTextTokensFromServer('/api/tokenizers/nerdstash/decode', ids);
-        case tokenizers.NERD2:
-            return decodeTextTokensFromServer('/api/tokenizers/nerdstash_v2/decode', ids);
-        case tokenizers.MISTRAL:
-            return decodeTextTokensFromServer('/api/tokenizers/mistral/decode', ids);
-        case tokenizers.YI:
-            return decodeTextTokensFromServer('/api/tokenizers/yi/decode', ids);
-        case tokenizers.OPENAI: {
-            const model = getTokenizerModel();
-            return decodeTextTokensFromServer('/api/tokenizers/openai/decode', ids, model);
-        }
-        default:
-            console.warn('Calling decodeTextTokens with unsupported tokenizer type', tokenizerType);
-            return '';
-    }
+    const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
+    if (!tokenizerEndpoints) {
+        console.warn('Unknown tokenizer type', tokenizerType);
+        return [];
+    }
+    let endpointUrl = tokenizerEndpoints.decode;
+    if (!endpointUrl) {
+        console.warn('This tokenizer type does not support decoding', tokenizerType);
+        return [];
+    }
+    if (tokenizerType === tokenizers.OPENAI) {
+        endpointUrl += `?model=${getTokenizerModel()}`;
+    }
+    return decodeTextTokensFromServer(endpointUrl, ids);
 }
 
 export async function initTokenizers() {
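A hedged usage sketch of the refactored entry points, assuming the module exports shown in the diff (getTextTokens, decodeTextTokens, tokenizers); return shapes follow the JSDoc above, and the calls are synchronous since the underlying requests use async: false:

// Round-trip a string through the LLaMA tokenizer endpoints.
const ids = getTextTokens(tokenizers.LLAMA, 'Hello world'); // number[] via /api/tokenizers/llama/encode
const text = decodeTextTokens(tokenizers.LLAMA, ids);       // decoded string via /api/tokenizers/llama/decode

// For OpenAI, the model is appended as a query parameter internally,
// e.g. '/api/tokenizers/openai/encode?model=...' (model name illustrative).
const openaiIds = getTextTokens(tokenizers.OPENAI, 'Hello world');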