Merge pull request #1503 from valadaptive/tokenizers-cleanup
Tokenizers cleanup
This commit is contained in:
commit
ae01e7419f
|
@ -874,7 +874,7 @@ async function getStatusKobold() {
|
||||||
|
|
||||||
const url = '/getstatus';
|
const url = '/getstatus';
|
||||||
|
|
||||||
let endpoint = getAPIServerUrl();
|
let endpoint = api_server;
|
||||||
|
|
||||||
if (!endpoint) {
|
if (!endpoint) {
|
||||||
console.warn('No endpoint for status check');
|
console.warn('No endpoint for status check');
|
||||||
|
@ -922,7 +922,9 @@ async function getStatusKobold() {
|
||||||
async function getStatusTextgen() {
|
async function getStatusTextgen() {
|
||||||
const url = '/api/textgenerationwebui/status';
|
const url = '/api/textgenerationwebui/status';
|
||||||
|
|
||||||
let endpoint = getAPIServerUrl();
|
let endpoint = textgen_settings.type === MANCER ?
|
||||||
|
MANCER_SERVER :
|
||||||
|
api_server_textgenerationwebui;
|
||||||
|
|
||||||
if (!endpoint) {
|
if (!endpoint) {
|
||||||
console.warn('No endpoint for status check');
|
console.warn('No endpoint for status check');
|
||||||
|
@ -1002,23 +1004,6 @@ export function resultCheckStatus() {
|
||||||
stopStatusLoading();
|
stopStatusLoading();
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(valadaptive): remove the usage of this function in the tokenizers code, then remove the function entirely
|
|
||||||
export function getAPIServerUrl() {
|
|
||||||
if (main_api == 'textgenerationwebui') {
|
|
||||||
if (textgen_settings.type === MANCER) {
|
|
||||||
return MANCER_SERVER;
|
|
||||||
}
|
|
||||||
|
|
||||||
return api_server_textgenerationwebui;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (main_api == 'kobold') {
|
|
||||||
return api_server;
|
|
||||||
}
|
|
||||||
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function selectCharacterById(id) {
|
export async function selectCharacterById(id) {
|
||||||
if (characters[id] == undefined) {
|
if (characters[id] == undefined) {
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import { characters, getAPIServerUrl, main_api, nai_settings, online_status, this_chid } from '../script.js';
|
import { characters, main_api, api_server, api_server_textgenerationwebui, nai_settings, online_status, this_chid } from '../script.js';
|
||||||
import { power_user, registerDebugFunction } from './power-user.js';
|
import { power_user, registerDebugFunction } from './power-user.js';
|
||||||
import { chat_completion_sources, model_list, oai_settings } from './openai.js';
|
import { chat_completion_sources, model_list, oai_settings } from './openai.js';
|
||||||
import { groups, selected_group } from './group-chats.js';
|
import { groups, selected_group } from './group-chats.js';
|
||||||
|
@ -18,9 +18,11 @@ export const tokenizers = {
|
||||||
LLAMA: 3,
|
LLAMA: 3,
|
||||||
NERD: 4,
|
NERD: 4,
|
||||||
NERD2: 5,
|
NERD2: 5,
|
||||||
API: 6,
|
API_CURRENT: 6,
|
||||||
MISTRAL: 7,
|
MISTRAL: 7,
|
||||||
YI: 8,
|
YI: 8,
|
||||||
|
API_TEXTGENERATIONWEBUI: 9,
|
||||||
|
API_KOBOLD: 10,
|
||||||
BEST_MATCH: 99,
|
BEST_MATCH: 99,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -33,6 +35,51 @@ export const SENTENCEPIECE_TOKENIZERS = [
|
||||||
//tokenizers.NERD2,
|
//tokenizers.NERD2,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
const TOKENIZER_URLS = {
|
||||||
|
[tokenizers.GPT2]: {
|
||||||
|
encode: '/api/tokenizers/gpt2/encode',
|
||||||
|
decode: '/api/tokenizers/gpt2/decode',
|
||||||
|
count: '/api/tokenizers/gpt2/encode',
|
||||||
|
},
|
||||||
|
[tokenizers.OPENAI]: {
|
||||||
|
encode: '/api/tokenizers/openai/encode',
|
||||||
|
decode: '/api/tokenizers/openai/decode',
|
||||||
|
count: '/api/tokenizers/openai/encode',
|
||||||
|
},
|
||||||
|
[tokenizers.LLAMA]: {
|
||||||
|
encode: '/api/tokenizers/llama/encode',
|
||||||
|
decode: '/api/tokenizers/llama/decode',
|
||||||
|
count: '/api/tokenizers/llama/encode',
|
||||||
|
},
|
||||||
|
[tokenizers.NERD]: {
|
||||||
|
encode: '/api/tokenizers/nerdstash/encode',
|
||||||
|
decode: '/api/tokenizers/nerdstash/decode',
|
||||||
|
count: '/api/tokenizers/nerdstash/encode',
|
||||||
|
},
|
||||||
|
[tokenizers.NERD2]: {
|
||||||
|
encode: '/api/tokenizers/nerdstash_v2/encode',
|
||||||
|
decode: '/api/tokenizers/nerdstash_v2/decode',
|
||||||
|
count: '/api/tokenizers/nerdstash_v2/encode',
|
||||||
|
},
|
||||||
|
[tokenizers.API_KOBOLD]: {
|
||||||
|
count: '/api/tokenizers/remote/kobold/count',
|
||||||
|
},
|
||||||
|
[tokenizers.MISTRAL]: {
|
||||||
|
encode: '/api/tokenizers/mistral/encode',
|
||||||
|
decode: '/api/tokenizers/mistral/decode',
|
||||||
|
count: '/api/tokenizers/mistral/encode',
|
||||||
|
},
|
||||||
|
[tokenizers.YI]: {
|
||||||
|
encode: '/api/tokenizers/yi/encode',
|
||||||
|
decode: '/api/tokenizers/yi/decode',
|
||||||
|
count: '/api/tokenizers/yi/encode',
|
||||||
|
},
|
||||||
|
[tokenizers.API_TEXTGENERATIONWEBUI]: {
|
||||||
|
encode: '/api/tokenizers/remote/textgenerationwebui/encode',
|
||||||
|
count: '/api/tokenizers/remote/textgenerationwebui/encode',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
const objectStore = new localforage.createInstance({ name: 'SillyTavern_ChatCompletions' });
|
const objectStore = new localforage.createInstance({ name: 'SillyTavern_ChatCompletions' });
|
||||||
|
|
||||||
let tokenCache = {};
|
let tokenCache = {};
|
||||||
|
@ -92,7 +139,18 @@ export function getFriendlyTokenizerName(forApi) {
|
||||||
|
|
||||||
if (forApi !== 'openai' && tokenizerId === tokenizers.BEST_MATCH) {
|
if (forApi !== 'openai' && tokenizerId === tokenizers.BEST_MATCH) {
|
||||||
tokenizerId = getTokenizerBestMatch(forApi);
|
tokenizerId = getTokenizerBestMatch(forApi);
|
||||||
tokenizerName = $(`#tokenizer option[value="${tokenizerId}"]`).text();
|
|
||||||
|
switch (tokenizerId) {
|
||||||
|
case tokenizers.API_KOBOLD:
|
||||||
|
tokenizerName = 'API (KoboldAI Classic)';
|
||||||
|
break;
|
||||||
|
case tokenizers.API_TEXTGENERATIONWEBUI:
|
||||||
|
tokenizerName = 'API (Text Completion)';
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
tokenizerName = $(`#tokenizer option[value="${tokenizerId}"]`).text();
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenizerName = forApi == 'openai'
|
tokenizerName = forApi == 'openai'
|
||||||
|
@ -135,11 +193,11 @@ export function getTokenizerBestMatch(forApi) {
|
||||||
|
|
||||||
if (!hasTokenizerError && isConnected) {
|
if (!hasTokenizerError && isConnected) {
|
||||||
if (forApi === 'kobold' && kai_flags.can_use_tokenization) {
|
if (forApi === 'kobold' && kai_flags.can_use_tokenization) {
|
||||||
return tokenizers.API;
|
return tokenizers.API_KOBOLD;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (forApi === 'textgenerationwebui' && isTokenizerSupported) {
|
if (forApi === 'textgenerationwebui' && isTokenizerSupported) {
|
||||||
return tokenizers.API;
|
return tokenizers.API_TEXTGENERATIONWEBUI;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,34 +207,42 @@ export function getTokenizerBestMatch(forApi) {
|
||||||
return tokenizers.NONE;
|
return tokenizers.NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the current remote tokenizer API based on the current text generation API.
|
||||||
|
function currentRemoteTokenizerAPI() {
|
||||||
|
switch (main_api) {
|
||||||
|
case 'kobold':
|
||||||
|
return tokenizers.API_KOBOLD;
|
||||||
|
case 'textgenerationwebui':
|
||||||
|
return tokenizers.API_TEXTGENERATIONWEBUI;
|
||||||
|
default:
|
||||||
|
return tokenizers.NONE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls the underlying tokenizer model to the token count for a string.
|
* Calls the underlying tokenizer model to the token count for a string.
|
||||||
* @param {number} type Tokenizer type.
|
* @param {number} type Tokenizer type.
|
||||||
* @param {string} str String to tokenize.
|
* @param {string} str String to tokenize.
|
||||||
* @param {number} padding Number of padding tokens.
|
|
||||||
* @returns {number} Token count.
|
* @returns {number} Token count.
|
||||||
*/
|
*/
|
||||||
function callTokenizer(type, str, padding) {
|
function callTokenizer(type, str) {
|
||||||
|
if (type === tokenizers.NONE) return guesstimate(str);
|
||||||
|
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case tokenizers.NONE:
|
case tokenizers.API_CURRENT:
|
||||||
return guesstimate(str) + padding;
|
return callTokenizer(currentRemoteTokenizerAPI(), str);
|
||||||
case tokenizers.GPT2:
|
case tokenizers.API_KOBOLD:
|
||||||
return countTokensRemote('/api/tokenizers/gpt2/encode', str, padding);
|
return countTokensFromKoboldAPI(str);
|
||||||
case tokenizers.LLAMA:
|
case tokenizers.API_TEXTGENERATIONWEBUI:
|
||||||
return countTokensRemote('/api/tokenizers/llama/encode', str, padding);
|
return countTokensFromTextgenAPI(str);
|
||||||
case tokenizers.NERD:
|
default: {
|
||||||
return countTokensRemote('/api/tokenizers/nerdstash/encode', str, padding);
|
const endpointUrl = TOKENIZER_URLS[type]?.count;
|
||||||
case tokenizers.NERD2:
|
if (!endpointUrl) {
|
||||||
return countTokensRemote('/api/tokenizers/nerdstash_v2/encode', str, padding);
|
console.warn('Unknown tokenizer type', type);
|
||||||
case tokenizers.MISTRAL:
|
return apiFailureTokenCount(str);
|
||||||
return countTokensRemote('/api/tokenizers/mistral/encode', str, padding);
|
}
|
||||||
case tokenizers.YI:
|
return countTokensFromServer(endpointUrl, str);
|
||||||
return countTokensRemote('/api/tokenizers/yi/encode', str, padding);
|
}
|
||||||
case tokenizers.API:
|
|
||||||
return countTokensRemote('/tokenize_via_api', str, padding);
|
|
||||||
default:
|
|
||||||
console.warn('Unknown tokenizer type', type);
|
|
||||||
return callTokenizer(tokenizers.NONE, str, padding);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,7 +285,7 @@ export function getTokenCount(str, padding = undefined) {
|
||||||
return cacheObject[cacheKey];
|
return cacheObject[cacheKey];
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = callTokenizer(tokenizerType, str, padding);
|
const result = callTokenizer(tokenizerType, str) + padding;
|
||||||
|
|
||||||
if (isNaN(result)) {
|
if (isNaN(result)) {
|
||||||
console.warn('Token count calculation returned NaN');
|
console.warn('Token count calculation returned NaN');
|
||||||
|
@ -391,76 +457,131 @@ function getTokenCacheObject() {
|
||||||
return tokenCache[String(chatId)];
|
return tokenCache[String(chatId)];
|
||||||
}
|
}
|
||||||
|
|
||||||
function getRemoteTokenizationParams(str) {
|
|
||||||
return {
|
|
||||||
text: str,
|
|
||||||
main_api,
|
|
||||||
api_type: textgen_settings.type,
|
|
||||||
url: getAPIServerUrl(),
|
|
||||||
legacy_api: main_api === 'textgenerationwebui' &&
|
|
||||||
textgen_settings.legacy_api &&
|
|
||||||
textgen_settings.type !== MANCER,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Counts token using the remote server API.
|
* Count tokens using the server API.
|
||||||
* @param {string} endpoint API endpoint.
|
* @param {string} endpoint API endpoint.
|
||||||
* @param {string} str String to tokenize.
|
* @param {string} str String to tokenize.
|
||||||
* @param {number} padding Number of padding tokens.
|
* @returns {number} Token count.
|
||||||
* @returns {number} Token count with padding.
|
|
||||||
*/
|
*/
|
||||||
function countTokensRemote(endpoint, str, padding) {
|
function countTokensFromServer(endpoint, str) {
|
||||||
let tokenCount = 0;
|
let tokenCount = 0;
|
||||||
|
|
||||||
jQuery.ajax({
|
jQuery.ajax({
|
||||||
async: false,
|
async: false,
|
||||||
type: 'POST',
|
type: 'POST',
|
||||||
url: endpoint,
|
url: endpoint,
|
||||||
data: JSON.stringify(getRemoteTokenizationParams(str)),
|
data: JSON.stringify({ text: str }),
|
||||||
dataType: 'json',
|
dataType: 'json',
|
||||||
contentType: 'application/json',
|
contentType: 'application/json',
|
||||||
success: function (data) {
|
success: function (data) {
|
||||||
if (typeof data.count === 'number') {
|
if (typeof data.count === 'number') {
|
||||||
tokenCount = data.count;
|
tokenCount = data.count;
|
||||||
} else {
|
} else {
|
||||||
tokenCount = guesstimate(str);
|
tokenCount = apiFailureTokenCount(str);
|
||||||
console.error('Error counting tokens');
|
|
||||||
|
|
||||||
if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
|
|
||||||
toastr.warning(
|
|
||||||
'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.',
|
|
||||||
'Error counting tokens',
|
|
||||||
{ timeOut: 10000, preventDuplicates: true },
|
|
||||||
);
|
|
||||||
|
|
||||||
sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
return tokenCount + padding;
|
return tokenCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Count tokens using the AI provider's API.
|
||||||
|
* @param {string} str String to tokenize.
|
||||||
|
* @returns {number} Token count.
|
||||||
|
*/
|
||||||
|
function countTokensFromKoboldAPI(str) {
|
||||||
|
let tokenCount = 0;
|
||||||
|
|
||||||
|
jQuery.ajax({
|
||||||
|
async: false,
|
||||||
|
type: 'POST',
|
||||||
|
url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count,
|
||||||
|
data: JSON.stringify({
|
||||||
|
text: str,
|
||||||
|
url: api_server,
|
||||||
|
}),
|
||||||
|
dataType: 'json',
|
||||||
|
contentType: 'application/json',
|
||||||
|
success: function (data) {
|
||||||
|
if (typeof data.count === 'number') {
|
||||||
|
tokenCount = data.count;
|
||||||
|
} else {
|
||||||
|
tokenCount = apiFailureTokenCount(str);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return tokenCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getTextgenAPITokenizationParams(str) {
|
||||||
|
return {
|
||||||
|
text: str,
|
||||||
|
api_type: textgen_settings.type,
|
||||||
|
url: api_server_textgenerationwebui,
|
||||||
|
legacy_api:
|
||||||
|
textgen_settings.legacy_api &&
|
||||||
|
textgen_settings.type !== MANCER,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Count tokens using the AI provider's API.
|
||||||
|
* @param {string} str String to tokenize.
|
||||||
|
* @returns {number} Token count.
|
||||||
|
*/
|
||||||
|
function countTokensFromTextgenAPI(str) {
|
||||||
|
let tokenCount = 0;
|
||||||
|
|
||||||
|
jQuery.ajax({
|
||||||
|
async: false,
|
||||||
|
type: 'POST',
|
||||||
|
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count,
|
||||||
|
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
|
||||||
|
dataType: 'json',
|
||||||
|
contentType: 'application/json',
|
||||||
|
success: function (data) {
|
||||||
|
if (typeof data.count === 'number') {
|
||||||
|
tokenCount = data.count;
|
||||||
|
} else {
|
||||||
|
tokenCount = apiFailureTokenCount(str);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return tokenCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
function apiFailureTokenCount(str) {
|
||||||
|
console.error('Error counting tokens');
|
||||||
|
|
||||||
|
if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
|
||||||
|
toastr.warning(
|
||||||
|
'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.',
|
||||||
|
'Error counting tokens',
|
||||||
|
{ timeOut: 10000, preventDuplicates: true },
|
||||||
|
);
|
||||||
|
|
||||||
|
sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
|
||||||
|
}
|
||||||
|
|
||||||
|
return guesstimate(str);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls the underlying tokenizer model to encode a string to tokens.
|
* Calls the underlying tokenizer model to encode a string to tokens.
|
||||||
* @param {string} endpoint API endpoint.
|
* @param {string} endpoint API endpoint.
|
||||||
* @param {string} str String to tokenize.
|
* @param {string} str String to tokenize.
|
||||||
* @param {string} model Tokenizer model.
|
|
||||||
* @returns {number[]} Array of token ids.
|
* @returns {number[]} Array of token ids.
|
||||||
*/
|
*/
|
||||||
function getTextTokensRemote(endpoint, str, model = '') {
|
function getTextTokensFromServer(endpoint, str) {
|
||||||
if (model) {
|
|
||||||
endpoint += `?model=${model}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
let ids = [];
|
let ids = [];
|
||||||
jQuery.ajax({
|
jQuery.ajax({
|
||||||
async: false,
|
async: false,
|
||||||
type: 'POST',
|
type: 'POST',
|
||||||
url: endpoint,
|
url: endpoint,
|
||||||
data: JSON.stringify(getRemoteTokenizationParams(str)),
|
data: JSON.stringify({ text: str }),
|
||||||
dataType: 'json',
|
dataType: 'json',
|
||||||
contentType: 'application/json',
|
contentType: 'application/json',
|
||||||
success: function (data) {
|
success: function (data) {
|
||||||
|
@ -475,16 +596,33 @@ function getTextTokensRemote(endpoint, str, model = '') {
|
||||||
return ids;
|
return ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls the AI provider's tokenize API to encode a string to tokens.
|
||||||
|
* @param {string} str String to tokenize.
|
||||||
|
* @returns {number[]} Array of token ids.
|
||||||
|
*/
|
||||||
|
function getTextTokensFromTextgenAPI(str) {
|
||||||
|
let ids = [];
|
||||||
|
jQuery.ajax({
|
||||||
|
async: false,
|
||||||
|
type: 'POST',
|
||||||
|
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode,
|
||||||
|
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
|
||||||
|
dataType: 'json',
|
||||||
|
contentType: 'application/json',
|
||||||
|
success: function (data) {
|
||||||
|
ids = data.ids;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls the underlying tokenizer model to decode token ids to text.
|
* Calls the underlying tokenizer model to decode token ids to text.
|
||||||
* @param {string} endpoint API endpoint.
|
* @param {string} endpoint API endpoint.
|
||||||
* @param {number[]} ids Array of token ids
|
* @param {number[]} ids Array of token ids
|
||||||
*/
|
*/
|
||||||
function decodeTextTokensRemote(endpoint, ids, model = '') {
|
function decodeTextTokensFromServer(endpoint, ids) {
|
||||||
if (model) {
|
|
||||||
endpoint += `?model=${model}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
let text = '';
|
let text = '';
|
||||||
jQuery.ajax({
|
jQuery.ajax({
|
||||||
async: false,
|
async: false,
|
||||||
|
@ -501,64 +639,62 @@ function decodeTextTokensRemote(endpoint, ids, model = '') {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Encodes a string to tokens using the remote server API.
|
* Encodes a string to tokens using the server API.
|
||||||
* @param {number} tokenizerType Tokenizer type.
|
* @param {number} tokenizerType Tokenizer type.
|
||||||
* @param {string} str String to tokenize.
|
* @param {string} str String to tokenize.
|
||||||
* @returns {number[]} Array of token ids.
|
* @returns {number[]} Array of token ids.
|
||||||
*/
|
*/
|
||||||
export function getTextTokens(tokenizerType, str) {
|
export function getTextTokens(tokenizerType, str) {
|
||||||
switch (tokenizerType) {
|
switch (tokenizerType) {
|
||||||
case tokenizers.GPT2:
|
case tokenizers.API_CURRENT:
|
||||||
return getTextTokensRemote('/api/tokenizers/gpt2/encode', str);
|
return getTextTokens(currentRemoteTokenizerAPI(), str);
|
||||||
case tokenizers.LLAMA:
|
case tokenizers.API_TEXTGENERATIONWEBUI:
|
||||||
return getTextTokensRemote('/api/tokenizers/llama/encode', str);
|
return getTextTokensFromTextgenAPI(str);
|
||||||
case tokenizers.NERD:
|
default: {
|
||||||
return getTextTokensRemote('/api/tokenizers/nerdstash/encode', str);
|
const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
|
||||||
case tokenizers.NERD2:
|
if (!tokenizerEndpoints) {
|
||||||
return getTextTokensRemote('/api/tokenizers/nerdstash_v2/encode', str);
|
apiFailureTokenCount(str);
|
||||||
case tokenizers.MISTRAL:
|
console.warn('Unknown tokenizer type', tokenizerType);
|
||||||
return getTextTokensRemote('/api/tokenizers/mistral/encode', str);
|
return [];
|
||||||
case tokenizers.YI:
|
}
|
||||||
return getTextTokensRemote('/api/tokenizers/yi/encode', str);
|
let endpointUrl = tokenizerEndpoints.encode;
|
||||||
case tokenizers.OPENAI: {
|
if (!endpointUrl) {
|
||||||
const model = getTokenizerModel();
|
apiFailureTokenCount(str);
|
||||||
return getTextTokensRemote('/api/tokenizers/openai/encode', str, model);
|
console.warn('This tokenizer type does not support encoding', tokenizerType);
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
if (tokenizerType === tokenizers.OPENAI) {
|
||||||
|
endpointUrl += `?model=${getTokenizerModel()}`;
|
||||||
|
}
|
||||||
|
return getTextTokensFromServer(endpointUrl, str);
|
||||||
}
|
}
|
||||||
case tokenizers.API:
|
|
||||||
return getTextTokensRemote('/tokenize_via_api', str);
|
|
||||||
default:
|
|
||||||
console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType);
|
|
||||||
return [];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Decodes token ids to text using the remote server API.
|
* Decodes token ids to text using the server API.
|
||||||
* @param {number} tokenizerType Tokenizer type.
|
* @param {number} tokenizerType Tokenizer type.
|
||||||
* @param {number[]} ids Array of token ids
|
* @param {number[]} ids Array of token ids
|
||||||
*/
|
*/
|
||||||
export function decodeTextTokens(tokenizerType, ids) {
|
export function decodeTextTokens(tokenizerType, ids) {
|
||||||
switch (tokenizerType) {
|
// Currently, neither remote API can decode, but this may change in the future. Put this guard here to be safe
|
||||||
case tokenizers.GPT2:
|
if (tokenizerType === tokenizers.API_CURRENT) {
|
||||||
return decodeTextTokensRemote('/api/tokenizers/gpt2/decode', ids);
|
return decodeTextTokens(tokenizers.NONE, ids);
|
||||||
case tokenizers.LLAMA:
|
|
||||||
return decodeTextTokensRemote('/api/tokenizers/llama/decode', ids);
|
|
||||||
case tokenizers.NERD:
|
|
||||||
return decodeTextTokensRemote('/api/tokenizers/nerdstash/decode', ids);
|
|
||||||
case tokenizers.NERD2:
|
|
||||||
return decodeTextTokensRemote('/api/tokenizers/nerdstash_v2/decode', ids);
|
|
||||||
case tokenizers.MISTRAL:
|
|
||||||
return decodeTextTokensRemote('/api/tokenizers/mistral/decode', ids);
|
|
||||||
case tokenizers.YI:
|
|
||||||
return decodeTextTokensRemote('/api/tokenizers/yi/decode', ids);
|
|
||||||
case tokenizers.OPENAI: {
|
|
||||||
const model = getTokenizerModel();
|
|
||||||
return decodeTextTokensRemote('/api/tokenizers/openai/decode', ids, model);
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
console.warn('Calling decodeTextTokens with unsupported tokenizer type', tokenizerType);
|
|
||||||
return '';
|
|
||||||
}
|
}
|
||||||
|
const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
|
||||||
|
if (!tokenizerEndpoints) {
|
||||||
|
console.warn('Unknown tokenizer type', tokenizerType);
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
let endpointUrl = tokenizerEndpoints.decode;
|
||||||
|
if (!endpointUrl) {
|
||||||
|
console.warn('This tokenizer type does not support decoding', tokenizerType);
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
if (tokenizerType === tokenizers.OPENAI) {
|
||||||
|
endpointUrl += `?model=${getTokenizerModel()}`;
|
||||||
|
}
|
||||||
|
return decodeTextTokensFromServer(endpointUrl, ids);
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function initTokenizers() {
|
export async function initTokenizers() {
|
||||||
|
|
152
server.js
152
server.js
|
@ -49,6 +49,7 @@ const { delay, getVersion, getConfigValue, color, uuidv4, tryParse, clientRelati
|
||||||
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
|
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
|
||||||
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
|
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
|
||||||
const { convertClaudePrompt } = require('./src/chat-completion');
|
const { convertClaudePrompt } = require('./src/chat-completion');
|
||||||
|
const { getOverrideHeaders, setAdditionalHeaders } = require('./src/additional-headers');
|
||||||
|
|
||||||
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
|
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
|
||||||
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
|
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
|
||||||
|
@ -119,70 +120,6 @@ const listen = getConfigValue('listen', false);
|
||||||
const API_OPENAI = 'https://api.openai.com/v1';
|
const API_OPENAI = 'https://api.openai.com/v1';
|
||||||
const API_CLAUDE = 'https://api.anthropic.com/v1';
|
const API_CLAUDE = 'https://api.anthropic.com/v1';
|
||||||
|
|
||||||
function getMancerHeaders() {
|
|
||||||
const apiKey = readSecret(SECRET_KEYS.MANCER);
|
|
||||||
|
|
||||||
return apiKey ? ({
|
|
||||||
'X-API-KEY': apiKey,
|
|
||||||
'Authorization': `Bearer ${apiKey}`,
|
|
||||||
}) : {};
|
|
||||||
}
|
|
||||||
|
|
||||||
function getAphroditeHeaders() {
|
|
||||||
const apiKey = readSecret(SECRET_KEYS.APHRODITE);
|
|
||||||
|
|
||||||
return apiKey ? ({
|
|
||||||
'X-API-KEY': apiKey,
|
|
||||||
'Authorization': `Bearer ${apiKey}`,
|
|
||||||
}) : {};
|
|
||||||
}
|
|
||||||
|
|
||||||
function getTabbyHeaders() {
|
|
||||||
const apiKey = readSecret(SECRET_KEYS.TABBY);
|
|
||||||
|
|
||||||
return apiKey ? ({
|
|
||||||
'x-api-key': apiKey,
|
|
||||||
'Authorization': `Bearer ${apiKey}`,
|
|
||||||
}) : {};
|
|
||||||
}
|
|
||||||
|
|
||||||
function getOverrideHeaders(urlHost) {
|
|
||||||
const requestOverrides = getConfigValue('requestOverrides', []);
|
|
||||||
const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
|
|
||||||
if (overrideHeaders && urlHost) {
|
|
||||||
return overrideHeaders;
|
|
||||||
} else {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sets additional headers for the request.
|
|
||||||
* @param {object} request Original request body
|
|
||||||
* @param {object} args New request arguments
|
|
||||||
* @param {string|null} server API server for new request
|
|
||||||
*/
|
|
||||||
function setAdditionalHeaders(request, args, server) {
|
|
||||||
let headers;
|
|
||||||
|
|
||||||
switch (request.body.api_type) {
|
|
||||||
case TEXTGEN_TYPES.MANCER:
|
|
||||||
headers = getMancerHeaders();
|
|
||||||
break;
|
|
||||||
case TEXTGEN_TYPES.APHRODITE:
|
|
||||||
headers = getAphroditeHeaders();
|
|
||||||
break;
|
|
||||||
case TEXTGEN_TYPES.TABBY:
|
|
||||||
headers = getTabbyHeaders();
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
Object.assign(args.headers, headers);
|
|
||||||
}
|
|
||||||
|
|
||||||
const SETTINGS_FILE = './public/settings.json';
|
const SETTINGS_FILE = './public/settings.json';
|
||||||
const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, TEXTGEN_TYPES, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants');
|
const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, TEXTGEN_TYPES, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants');
|
||||||
|
|
||||||
|
@ -1774,93 +1711,6 @@ async function sendAI21Request(request, response) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
app.post('/tokenize_via_api', jsonParser, async function (request, response) {
|
|
||||||
if (!request.body) {
|
|
||||||
return response.sendStatus(400);
|
|
||||||
}
|
|
||||||
const text = String(request.body.text) || '';
|
|
||||||
const api = String(request.body.main_api);
|
|
||||||
const baseUrl = String(request.body.url);
|
|
||||||
const legacyApi = Boolean(request.body.legacy_api);
|
|
||||||
|
|
||||||
try {
|
|
||||||
if (api == 'textgenerationwebui') {
|
|
||||||
const args = {
|
|
||||||
method: 'POST',
|
|
||||||
headers: { 'Content-Type': 'application/json' },
|
|
||||||
};
|
|
||||||
|
|
||||||
setAdditionalHeaders(request, args, null);
|
|
||||||
|
|
||||||
// Convert to string + remove trailing slash + /v1 suffix
|
|
||||||
let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, '');
|
|
||||||
|
|
||||||
if (legacyApi) {
|
|
||||||
url += '/v1/token-count';
|
|
||||||
args.body = JSON.stringify({ 'prompt': text });
|
|
||||||
} else {
|
|
||||||
switch (request.body.api_type) {
|
|
||||||
case TEXTGEN_TYPES.TABBY:
|
|
||||||
url += '/v1/token/encode';
|
|
||||||
args.body = JSON.stringify({ 'text': text });
|
|
||||||
break;
|
|
||||||
case TEXTGEN_TYPES.KOBOLDCPP:
|
|
||||||
url += '/api/extra/tokencount';
|
|
||||||
args.body = JSON.stringify({ 'prompt': text });
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
url += '/v1/internal/encode';
|
|
||||||
args.body = JSON.stringify({ 'text': text });
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const result = await fetch(url, args);
|
|
||||||
|
|
||||||
if (!result.ok) {
|
|
||||||
console.log(`API returned error: ${result.status} ${result.statusText}`);
|
|
||||||
return response.send({ error: true });
|
|
||||||
}
|
|
||||||
|
|
||||||
const data = await result.json();
|
|
||||||
const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
|
|
||||||
const ids = legacyApi ? [] : (data?.tokens ?? []);
|
|
||||||
|
|
||||||
return response.send({ count, ids });
|
|
||||||
}
|
|
||||||
|
|
||||||
else if (api == 'kobold') {
|
|
||||||
const args = {
|
|
||||||
method: 'POST',
|
|
||||||
body: JSON.stringify({ 'prompt': text }),
|
|
||||||
headers: { 'Content-Type': 'application/json' },
|
|
||||||
};
|
|
||||||
|
|
||||||
let url = String(baseUrl).replace(/\/$/, '');
|
|
||||||
url += '/extra/tokencount';
|
|
||||||
|
|
||||||
const result = await fetch(url, args);
|
|
||||||
|
|
||||||
if (!result.ok) {
|
|
||||||
console.log(`API returned error: ${result.status} ${result.statusText}`);
|
|
||||||
return response.send({ error: true });
|
|
||||||
}
|
|
||||||
|
|
||||||
const data = await result.json();
|
|
||||||
const count = data['value'];
|
|
||||||
return response.send({ count: count, ids: [] });
|
|
||||||
}
|
|
||||||
|
|
||||||
else {
|
|
||||||
console.log('Unknown API', api);
|
|
||||||
return response.send({ error: true });
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.log(error);
|
|
||||||
return response.send({ error: true });
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Redirect a deprecated API endpoint URL to its replacement. Because fetch, form submissions, and $.ajax follow
|
* Redirect a deprecated API endpoint URL to its replacement. Because fetch, form submissions, and $.ajax follow
|
||||||
* redirects, this is transparent to client-side code.
|
* redirects, this is transparent to client-side code.
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
const { TEXTGEN_TYPES } = require('./constants');
|
||||||
|
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||||
|
const { getConfigValue } = require('./util');
|
||||||
|
|
||||||
|
function getMancerHeaders() {
|
||||||
|
const apiKey = readSecret(SECRET_KEYS.MANCER);
|
||||||
|
|
||||||
|
return apiKey ? ({
|
||||||
|
'X-API-KEY': apiKey,
|
||||||
|
'Authorization': `Bearer ${apiKey}`,
|
||||||
|
}) : {};
|
||||||
|
}
|
||||||
|
|
||||||
|
function getAphroditeHeaders() {
|
||||||
|
const apiKey = readSecret(SECRET_KEYS.APHRODITE);
|
||||||
|
|
||||||
|
return apiKey ? ({
|
||||||
|
'X-API-KEY': apiKey,
|
||||||
|
'Authorization': `Bearer ${apiKey}`,
|
||||||
|
}) : {};
|
||||||
|
}
|
||||||
|
|
||||||
|
function getTabbyHeaders() {
|
||||||
|
const apiKey = readSecret(SECRET_KEYS.TABBY);
|
||||||
|
|
||||||
|
return apiKey ? ({
|
||||||
|
'x-api-key': apiKey,
|
||||||
|
'Authorization': `Bearer ${apiKey}`,
|
||||||
|
}) : {};
|
||||||
|
}
|
||||||
|
|
||||||
|
function getOverrideHeaders(urlHost) {
|
||||||
|
const requestOverrides = getConfigValue('requestOverrides', []);
|
||||||
|
const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
|
||||||
|
if (overrideHeaders && urlHost) {
|
||||||
|
return overrideHeaders;
|
||||||
|
} else {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets additional headers for the request.
|
||||||
|
* @param {object} request Original request body
|
||||||
|
* @param {object} args New request arguments
|
||||||
|
* @param {string|null} server API server for new request
|
||||||
|
*/
|
||||||
|
function setAdditionalHeaders(request, args, server) {
|
||||||
|
let headers;
|
||||||
|
|
||||||
|
switch (request.body.api_type) {
|
||||||
|
case TEXTGEN_TYPES.MANCER:
|
||||||
|
headers = getMancerHeaders();
|
||||||
|
break;
|
||||||
|
case TEXTGEN_TYPES.APHRODITE:
|
||||||
|
headers = getAphroditeHeaders();
|
||||||
|
break;
|
||||||
|
case TEXTGEN_TYPES.TABBY:
|
||||||
|
headers = getTabbyHeaders();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
Object.assign(args.headers, headers);
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
getOverrideHeaders,
|
||||||
|
setAdditionalHeaders,
|
||||||
|
};
|
|
@ -6,7 +6,9 @@ const tiktoken = require('@dqbd/tiktoken');
|
||||||
const { Tokenizer } = require('@agnai/web-tokenizers');
|
const { Tokenizer } = require('@agnai/web-tokenizers');
|
||||||
const { convertClaudePrompt } = require('../chat-completion');
|
const { convertClaudePrompt } = require('../chat-completion');
|
||||||
const { readSecret, SECRET_KEYS } = require('./secrets');
|
const { readSecret, SECRET_KEYS } = require('./secrets');
|
||||||
|
const { TEXTGEN_TYPES } = require('../constants');
|
||||||
const { jsonParser } = require('../express-common');
|
const { jsonParser } = require('../express-common');
|
||||||
|
const { setAdditionalHeaders } = require('../additional-headers');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
|
* @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
|
||||||
|
@ -534,6 +536,96 @@ router.post('/openai/count', jsonParser, async function (req, res) {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
router.post('/remote/kobold/count', jsonParser, async function (request, response) {
|
||||||
|
if (!request.body) {
|
||||||
|
return response.sendStatus(400);
|
||||||
|
}
|
||||||
|
const text = String(request.body.text) || '';
|
||||||
|
const baseUrl = String(request.body.url);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const args = {
|
||||||
|
method: 'POST',
|
||||||
|
body: JSON.stringify({ 'prompt': text }),
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = String(baseUrl).replace(/\/$/, '');
|
||||||
|
url += '/extra/tokencount';
|
||||||
|
|
||||||
|
const result = await fetch(url, args);
|
||||||
|
|
||||||
|
if (!result.ok) {
|
||||||
|
console.log(`API returned error: ${result.status} ${result.statusText}`);
|
||||||
|
return response.send({ error: true });
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await result.json();
|
||||||
|
const count = data['value'];
|
||||||
|
return response.send({ count, ids: [] });
|
||||||
|
} catch (error) {
|
||||||
|
console.log(error);
|
||||||
|
return response.send({ error: true });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
router.post('/remote/textgenerationwebui/encode', jsonParser, async function (request, response) {
|
||||||
|
if (!request.body) {
|
||||||
|
return response.sendStatus(400);
|
||||||
|
}
|
||||||
|
const text = String(request.body.text) || '';
|
||||||
|
const baseUrl = String(request.body.url);
|
||||||
|
const legacyApi = Boolean(request.body.legacy_api);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const args = {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
};
|
||||||
|
|
||||||
|
setAdditionalHeaders(request, args, null);
|
||||||
|
|
||||||
|
// Convert to string + remove trailing slash + /v1 suffix
|
||||||
|
let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, '');
|
||||||
|
|
||||||
|
if (legacyApi) {
|
||||||
|
url += '/v1/token-count';
|
||||||
|
args.body = JSON.stringify({ 'prompt': text });
|
||||||
|
} else {
|
||||||
|
switch (request.body.api_type) {
|
||||||
|
case TEXTGEN_TYPES.TABBY:
|
||||||
|
url += '/v1/token/encode';
|
||||||
|
args.body = JSON.stringify({ 'text': text });
|
||||||
|
break;
|
||||||
|
case TEXTGEN_TYPES.KOBOLDCPP:
|
||||||
|
url += '/api/extra/tokencount';
|
||||||
|
args.body = JSON.stringify({ 'prompt': text });
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
url += '/v1/internal/encode';
|
||||||
|
args.body = JSON.stringify({ 'text': text });
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await fetch(url, args);
|
||||||
|
|
||||||
|
if (!result.ok) {
|
||||||
|
console.log(`API returned error: ${result.status} ${result.statusText}`);
|
||||||
|
return response.send({ error: true });
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await result.json();
|
||||||
|
const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
|
||||||
|
const ids = legacyApi ? [] : (data?.tokens ?? []);
|
||||||
|
|
||||||
|
return response.send({ count, ids });
|
||||||
|
} catch (error) {
|
||||||
|
console.log(error);
|
||||||
|
return response.send({ error: true });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
TEXT_COMPLETION_MODELS,
|
TEXT_COMPLETION_MODELS,
|
||||||
getTokenizerModel,
|
getTokenizerModel,
|
||||||
|
|
Loading…
Reference in New Issue