Merge pull request #1503 from valadaptive/tokenizers-cleanup

Tokenizers cleanup
Cohee 2023-12-10 16:35:52 +02:00 committed by GitHub
commit ae01e7419f
5 changed files with 416 additions and 281 deletions

public/script.js

@@ -874,7 +874,7 @@ async function getStatusKobold() {
     const url = '/getstatus';
 
-    let endpoint = getAPIServerUrl();
+    let endpoint = api_server;
 
     if (!endpoint) {
         console.warn('No endpoint for status check');
@@ -922,7 +922,9 @@ async function getStatusKobold() {
 async function getStatusTextgen() {
     const url = '/api/textgenerationwebui/status';
 
-    let endpoint = getAPIServerUrl();
+    let endpoint = textgen_settings.type === MANCER ?
+        MANCER_SERVER :
+        api_server_textgenerationwebui;
 
     if (!endpoint) {
         console.warn('No endpoint for status check');
@@ -1002,23 +1004,6 @@ export function resultCheckStatus() {
     stopStatusLoading();
 }
 
-// TODO(valadaptive): remove the usage of this function in the tokenizers code, then remove the function entirely
-export function getAPIServerUrl() {
-    if (main_api == 'textgenerationwebui') {
-        if (textgen_settings.type === MANCER) {
-            return MANCER_SERVER;
-        }
-
-        return api_server_textgenerationwebui;
-    }
-
-    if (main_api == 'kobold') {
-        return api_server;
-    }
-
-    return '';
-}
-
 export async function selectCharacterById(id) {
     if (characters[id] == undefined) {
         return;
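For a concrete sense of the change, the endpoint lookup that getAPIServerUrl() used to centralize is now inlined at each status check; a minimal sketch of the text-completion case (all names are module-level variables already imported in script.js):

    // Mancer is a hosted service with a fixed base URL; any other text-completion
    // backend uses the server address the user configured.
    const endpoint = textgen_settings.type === MANCER
        ? MANCER_SERVER
        : api_server_textgenerationwebui;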

public/scripts/tokenizers.js

@@ -1,4 +1,4 @@
-import { characters, getAPIServerUrl, main_api, nai_settings, online_status, this_chid } from '../script.js';
+import { characters, main_api, api_server, api_server_textgenerationwebui, nai_settings, online_status, this_chid } from '../script.js';
 import { power_user, registerDebugFunction } from './power-user.js';
 import { chat_completion_sources, model_list, oai_settings } from './openai.js';
 import { groups, selected_group } from './group-chats.js';
@@ -18,9 +18,11 @@ export const tokenizers = {
     LLAMA: 3,
     NERD: 4,
     NERD2: 5,
-    API: 6,
+    API_CURRENT: 6,
     MISTRAL: 7,
     YI: 8,
+    API_TEXTGENERATIONWEBUI: 9,
+    API_KOBOLD: 10,
     BEST_MATCH: 99,
 };
@@ -33,6 +35,51 @@ export const SENTENCEPIECE_TOKENIZERS = [
     //tokenizers.NERD2,
 ];
 
+const TOKENIZER_URLS = {
+    [tokenizers.GPT2]: {
+        encode: '/api/tokenizers/gpt2/encode',
+        decode: '/api/tokenizers/gpt2/decode',
+        count: '/api/tokenizers/gpt2/encode',
+    },
+    [tokenizers.OPENAI]: {
+        encode: '/api/tokenizers/openai/encode',
+        decode: '/api/tokenizers/openai/decode',
+        count: '/api/tokenizers/openai/encode',
+    },
+    [tokenizers.LLAMA]: {
+        encode: '/api/tokenizers/llama/encode',
+        decode: '/api/tokenizers/llama/decode',
+        count: '/api/tokenizers/llama/encode',
+    },
+    [tokenizers.NERD]: {
+        encode: '/api/tokenizers/nerdstash/encode',
+        decode: '/api/tokenizers/nerdstash/decode',
+        count: '/api/tokenizers/nerdstash/encode',
+    },
+    [tokenizers.NERD2]: {
+        encode: '/api/tokenizers/nerdstash_v2/encode',
+        decode: '/api/tokenizers/nerdstash_v2/decode',
+        count: '/api/tokenizers/nerdstash_v2/encode',
+    },
+    [tokenizers.API_KOBOLD]: {
+        count: '/api/tokenizers/remote/kobold/count',
+    },
+    [tokenizers.MISTRAL]: {
+        encode: '/api/tokenizers/mistral/encode',
+        decode: '/api/tokenizers/mistral/decode',
+        count: '/api/tokenizers/mistral/encode',
+    },
+    [tokenizers.YI]: {
+        encode: '/api/tokenizers/yi/encode',
+        decode: '/api/tokenizers/yi/decode',
+        count: '/api/tokenizers/yi/encode',
+    },
+    [tokenizers.API_TEXTGENERATIONWEBUI]: {
+        encode: '/api/tokenizers/remote/textgenerationwebui/encode',
+        count: '/api/tokenizers/remote/textgenerationwebui/encode',
+    },
+};
+
 const objectStore = new localforage.createInstance({ name: 'SillyTavern_ChatCompletions' });
 
 let tokenCache = {};
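To make the new table's shape concrete, a small lookup sketch (keys and routes exactly as defined above):

    // Local tokenizers reuse their encode route for counting. The remote Kobold
    // tokenizer only exposes a count route, so encode/decode are undefined for it.
    const llamaCount = TOKENIZER_URLS[tokenizers.LLAMA].count;         // '/api/tokenizers/llama/encode'
    const koboldCount = TOKENIZER_URLS[tokenizers.API_KOBOLD].count;   // '/api/tokenizers/remote/kobold/count'
    const koboldDecode = TOKENIZER_URLS[tokenizers.API_KOBOLD].decode; // undefined; callers must check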
@@ -92,7 +139,18 @@ export function getFriendlyTokenizerName(forApi) {
     if (forApi !== 'openai' && tokenizerId === tokenizers.BEST_MATCH) {
         tokenizerId = getTokenizerBestMatch(forApi);
-        tokenizerName = $(`#tokenizer option[value="${tokenizerId}"]`).text();
+
+        switch (tokenizerId) {
+            case tokenizers.API_KOBOLD:
+                tokenizerName = 'API (KoboldAI Classic)';
+                break;
+            case tokenizers.API_TEXTGENERATIONWEBUI:
+                tokenizerName = 'API (Text Completion)';
+                break;
+            default:
+                tokenizerName = $(`#tokenizer option[value="${tokenizerId}"]`).text();
+                break;
+        }
     }
 
     tokenizerName = forApi == 'openai'
@@ -135,11 +193,11 @@ export function getTokenizerBestMatch(forApi) {
     if (!hasTokenizerError && isConnected) {
         if (forApi === 'kobold' && kai_flags.can_use_tokenization) {
-            return tokenizers.API;
+            return tokenizers.API_KOBOLD;
         }
 
         if (forApi === 'textgenerationwebui' && isTokenizerSupported) {
-            return tokenizers.API;
+            return tokenizers.API_TEXTGENERATIONWEBUI;
        }
     }
@@ -149,34 +207,42 @@ export function getTokenizerBestMatch(forApi) {
     return tokenizers.NONE;
 }
 
+// Get the current remote tokenizer API based on the current text generation API.
+function currentRemoteTokenizerAPI() {
+    switch (main_api) {
+        case 'kobold':
+            return tokenizers.API_KOBOLD;
+        case 'textgenerationwebui':
+            return tokenizers.API_TEXTGENERATIONWEBUI;
+        default:
+            return tokenizers.NONE;
+    }
+}
+
 /**
  * Calls the underlying tokenizer model to the token count for a string.
  * @param {number} type Tokenizer type.
  * @param {string} str String to tokenize.
- * @param {number} padding Number of padding tokens.
  * @returns {number} Token count.
  */
-function callTokenizer(type, str, padding) {
+function callTokenizer(type, str) {
+    if (type === tokenizers.NONE) return guesstimate(str);
+
     switch (type) {
-        case tokenizers.NONE:
-            return guesstimate(str) + padding;
-        case tokenizers.GPT2:
-            return countTokensRemote('/api/tokenizers/gpt2/encode', str, padding);
-        case tokenizers.LLAMA:
-            return countTokensRemote('/api/tokenizers/llama/encode', str, padding);
-        case tokenizers.NERD:
-            return countTokensRemote('/api/tokenizers/nerdstash/encode', str, padding);
-        case tokenizers.NERD2:
-            return countTokensRemote('/api/tokenizers/nerdstash_v2/encode', str, padding);
-        case tokenizers.MISTRAL:
-            return countTokensRemote('/api/tokenizers/mistral/encode', str, padding);
-        case tokenizers.YI:
-            return countTokensRemote('/api/tokenizers/yi/encode', str, padding);
-        case tokenizers.API:
-            return countTokensRemote('/tokenize_via_api', str, padding);
-        default:
-            console.warn('Unknown tokenizer type', type);
-            return callTokenizer(tokenizers.NONE, str, padding);
+        case tokenizers.API_CURRENT:
+            return callTokenizer(currentRemoteTokenizerAPI(), str);
+        case tokenizers.API_KOBOLD:
+            return countTokensFromKoboldAPI(str);
+        case tokenizers.API_TEXTGENERATIONWEBUI:
+            return countTokensFromTextgenAPI(str);
+        default: {
+            const endpointUrl = TOKENIZER_URLS[type]?.count;
+            if (!endpointUrl) {
+                console.warn('Unknown tokenizer type', type);
+                return apiFailureTokenCount(str);
+            }
+            return countTokensFromServer(endpointUrl, str);
+        }
     }
 }
@@ -219,7 +285,7 @@ export function getTokenCount(str, padding = undefined) {
         return cacheObject[cacheKey];
     }
 
-    const result = callTokenizer(tokenizerType, str, padding);
+    const result = callTokenizer(tokenizerType, str) + padding;
 
     if (isNaN(result)) {
         console.warn('Token count calculation returned NaN');
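Note the behavioral detail preserved by this hunk: padding is now applied once by getTokenCount instead of inside every callTokenizer branch. An illustrative call (string and padding value made up):

    // Same result as before the refactor: tokenize first, then add padding on top.
    const padded = getTokenCount('Hello there!', 64);
    // internally: callTokenizer(tokenizerType, 'Hello there!') + 64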
@@ -391,76 +457,131 @@ function getTokenCacheObject() {
     return tokenCache[String(chatId)];
 }
 
-function getRemoteTokenizationParams(str) {
-    return {
-        text: str,
-        main_api,
-        api_type: textgen_settings.type,
-        url: getAPIServerUrl(),
-        legacy_api: main_api === 'textgenerationwebui' &&
-            textgen_settings.legacy_api &&
-            textgen_settings.type !== MANCER,
-    };
-}
-
 /**
- * Counts token using the remote server API.
+ * Count tokens using the server API.
  * @param {string} endpoint API endpoint.
  * @param {string} str String to tokenize.
- * @param {number} padding Number of padding tokens.
- * @returns {number} Token count with padding.
+ * @returns {number} Token count.
  */
-function countTokensRemote(endpoint, str, padding) {
+function countTokensFromServer(endpoint, str) {
     let tokenCount = 0;
 
     jQuery.ajax({
         async: false,
         type: 'POST',
         url: endpoint,
-        data: JSON.stringify(getRemoteTokenizationParams(str)),
+        data: JSON.stringify({ text: str }),
         dataType: 'json',
         contentType: 'application/json',
         success: function (data) {
             if (typeof data.count === 'number') {
                 tokenCount = data.count;
             } else {
-                tokenCount = guesstimate(str);
-                console.error('Error counting tokens');
-
-                if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
-                    toastr.warning(
-                        'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.',
-                        'Error counting tokens',
-                        { timeOut: 10000, preventDuplicates: true },
-                    );
-
-                    sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
-                }
+                tokenCount = apiFailureTokenCount(str);
             }
         },
     });
 
-    return tokenCount + padding;
+    return tokenCount;
+}
+
+/**
+ * Count tokens using the AI provider's API.
+ * @param {string} str String to tokenize.
+ * @returns {number} Token count.
+ */
+function countTokensFromKoboldAPI(str) {
+    let tokenCount = 0;
+
+    jQuery.ajax({
+        async: false,
+        type: 'POST',
+        url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count,
+        data: JSON.stringify({
+            text: str,
+            url: api_server,
+        }),
+        dataType: 'json',
+        contentType: 'application/json',
+        success: function (data) {
+            if (typeof data.count === 'number') {
+                tokenCount = data.count;
+            } else {
+                tokenCount = apiFailureTokenCount(str);
+            }
+        },
+    });
+
+    return tokenCount;
+}
+
+function getTextgenAPITokenizationParams(str) {
+    return {
+        text: str,
+        api_type: textgen_settings.type,
+        url: api_server_textgenerationwebui,
+        legacy_api:
+            textgen_settings.legacy_api &&
+            textgen_settings.type !== MANCER,
+    };
+}
+
+/**
+ * Count tokens using the AI provider's API.
+ * @param {string} str String to tokenize.
+ * @returns {number} Token count.
+ */
+function countTokensFromTextgenAPI(str) {
+    let tokenCount = 0;
+
+    jQuery.ajax({
+        async: false,
+        type: 'POST',
+        url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count,
+        data: JSON.stringify(getTextgenAPITokenizationParams(str)),
+        dataType: 'json',
+        contentType: 'application/json',
+        success: function (data) {
+            if (typeof data.count === 'number') {
+                tokenCount = data.count;
+            } else {
+                tokenCount = apiFailureTokenCount(str);
+            }
+        },
+    });
+
+    return tokenCount;
+}
+
+function apiFailureTokenCount(str) {
+    console.error('Error counting tokens');
+
+    if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
+        toastr.warning(
+            'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.',
+            'Error counting tokens',
+            { timeOut: 10000, preventDuplicates: true },
+        );
+
+        sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
+    }
+
+    return guesstimate(str);
 }
 
 /**
  * Calls the underlying tokenizer model to encode a string to tokens.
  * @param {string} endpoint API endpoint.
  * @param {string} str String to tokenize.
- * @param {string} model Tokenizer model.
  * @returns {number[]} Array of token ids.
  */
-function getTextTokensRemote(endpoint, str, model = '') {
-    if (model) {
-        endpoint += `?model=${model}`;
-    }
-
+function getTextTokensFromServer(endpoint, str) {
     let ids = [];
     jQuery.ajax({
         async: false,
         type: 'POST',
         url: endpoint,
-        data: JSON.stringify(getRemoteTokenizationParams(str)),
+        data: JSON.stringify({ text: str }),
         dataType: 'json',
         contentType: 'application/json',
         success: function (data) {
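For reference, a sketch of the request body these remote textgen calls produce, with hypothetical concrete values filled in (field names from getTextgenAPITokenizationParams above):

    // Hypothetical example: a local text-generation-webui instance, modern API.
    const body = {
        text: 'Hello there!',
        api_type: 'ooba',             // textgen_settings.type
        url: 'http://127.0.0.1:5000', // api_server_textgenerationwebui
        legacy_api: false,            // true only for legacy non-Mancer backends
    };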
@@ -475,16 +596,33 @@ function getTextTokensRemote(endpoint, str, model = '') {
     return ids;
 }
 
+/**
+ * Calls the AI provider's tokenize API to encode a string to tokens.
+ * @param {string} str String to tokenize.
+ * @returns {number[]} Array of token ids.
+ */
+function getTextTokensFromTextgenAPI(str) {
+    let ids = [];
+    jQuery.ajax({
+        async: false,
+        type: 'POST',
+        url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode,
+        data: JSON.stringify(getTextgenAPITokenizationParams(str)),
+        dataType: 'json',
+        contentType: 'application/json',
+        success: function (data) {
+            ids = data.ids;
+        },
+    });
+    return ids;
+}
+
 /**
  * Calls the underlying tokenizer model to decode token ids to text.
  * @param {string} endpoint API endpoint.
  * @param {number[]} ids Array of token ids
  */
-function decodeTextTokensRemote(endpoint, ids, model = '') {
-    if (model) {
-        endpoint += `?model=${model}`;
-    }
-
+function decodeTextTokensFromServer(endpoint, ids) {
     let text = '';
     jQuery.ajax({
         async: false,
@@ -501,64 +639,62 @@ function decodeTextTokensRemote(endpoint, ids, model = '') {
 }
 
 /**
- * Encodes a string to tokens using the remote server API.
+ * Encodes a string to tokens using the server API.
  * @param {number} tokenizerType Tokenizer type.
  * @param {string} str String to tokenize.
  * @returns {number[]} Array of token ids.
  */
 export function getTextTokens(tokenizerType, str) {
     switch (tokenizerType) {
-        case tokenizers.GPT2:
-            return getTextTokensRemote('/api/tokenizers/gpt2/encode', str);
-        case tokenizers.LLAMA:
-            return getTextTokensRemote('/api/tokenizers/llama/encode', str);
-        case tokenizers.NERD:
-            return getTextTokensRemote('/api/tokenizers/nerdstash/encode', str);
-        case tokenizers.NERD2:
-            return getTextTokensRemote('/api/tokenizers/nerdstash_v2/encode', str);
-        case tokenizers.MISTRAL:
-            return getTextTokensRemote('/api/tokenizers/mistral/encode', str);
-        case tokenizers.YI:
-            return getTextTokensRemote('/api/tokenizers/yi/encode', str);
-        case tokenizers.OPENAI: {
-            const model = getTokenizerModel();
-            return getTextTokensRemote('/api/tokenizers/openai/encode', str, model);
+        case tokenizers.API_CURRENT:
+            return getTextTokens(currentRemoteTokenizerAPI(), str);
+        case tokenizers.API_TEXTGENERATIONWEBUI:
+            return getTextTokensFromTextgenAPI(str);
+        default: {
+            const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
+            if (!tokenizerEndpoints) {
+                apiFailureTokenCount(str);
+                console.warn('Unknown tokenizer type', tokenizerType);
+                return [];
+            }
+            let endpointUrl = tokenizerEndpoints.encode;
+            if (!endpointUrl) {
+                apiFailureTokenCount(str);
+                console.warn('This tokenizer type does not support encoding', tokenizerType);
+                return [];
+            }
+            if (tokenizerType === tokenizers.OPENAI) {
+                endpointUrl += `?model=${getTokenizerModel()}`;
+            }
+            return getTextTokensFromServer(endpointUrl, str);
         }
-        case tokenizers.API:
-            return getTextTokensRemote('/tokenize_via_api', str);
-        default:
-            console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType);
-            return [];
     }
 }
 
 /**
- * Decodes token ids to text using the remote server API.
+ * Decodes token ids to text using the server API.
  * @param {number} tokenizerType Tokenizer type.
  * @param {number[]} ids Array of token ids
  */
 export function decodeTextTokens(tokenizerType, ids) {
-    switch (tokenizerType) {
-        case tokenizers.GPT2:
-            return decodeTextTokensRemote('/api/tokenizers/gpt2/decode', ids);
-        case tokenizers.LLAMA:
-            return decodeTextTokensRemote('/api/tokenizers/llama/decode', ids);
-        case tokenizers.NERD:
-            return decodeTextTokensRemote('/api/tokenizers/nerdstash/decode', ids);
-        case tokenizers.NERD2:
-            return decodeTextTokensRemote('/api/tokenizers/nerdstash_v2/decode', ids);
-        case tokenizers.MISTRAL:
-            return decodeTextTokensRemote('/api/tokenizers/mistral/decode', ids);
-        case tokenizers.YI:
-            return decodeTextTokensRemote('/api/tokenizers/yi/decode', ids);
-        case tokenizers.OPENAI: {
-            const model = getTokenizerModel();
-            return decodeTextTokensRemote('/api/tokenizers/openai/decode', ids, model);
-        }
-        default:
-            console.warn('Calling decodeTextTokens with unsupported tokenizer type', tokenizerType);
-            return '';
+    // Currently, neither remote API can decode, but this may change in the future. Put this guard here to be safe
+    if (tokenizerType === tokenizers.API_CURRENT) {
+        return decodeTextTokens(tokenizers.NONE, ids);
     }
+    const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
+    if (!tokenizerEndpoints) {
+        console.warn('Unknown tokenizer type', tokenizerType);
+        return [];
+    }
+    let endpointUrl = tokenizerEndpoints.decode;
+    if (!endpointUrl) {
+        console.warn('This tokenizer type does not support decoding', tokenizerType);
+        return [];
+    }
+    if (tokenizerType === tokenizers.OPENAI) {
+        endpointUrl += `?model=${getTokenizerModel()}`;
+    }
+    return decodeTextTokensFromServer(endpointUrl, ids);
 }
 
 export async function initTokenizers() {
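Taken together, the exported entry points now all dispatch through the table; a short usage sketch (ids from the tokenizers enum above; strings illustrative, and decode only works for tokenizers whose table entry has a decode route):

    // API_CURRENT forwards to the active backend's remote tokenizer:
    // 'kobold' -> API_KOBOLD, 'textgenerationwebui' -> API_TEXTGENERATIONWEBUI, else NONE.
    const ids = getTextTokens(tokenizers.API_CURRENT, 'Hello there!');

    // Static tokenizers resolve through TOKENIZER_URLS instead:
    const llamaIds = getTextTokens(tokenizers.LLAMA, 'Hello there!'); // POST /api/tokenizers/llama/encode
    const text = decodeTextTokens(tokenizers.LLAMA, llamaIds);        // POST /api/tokenizers/llama/decode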

server.js

@@ -49,6 +49,7 @@ const { delay, getVersion, getConfigValue, color, uuidv4, tryParse, clientRelati
 const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
 const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
 const { convertClaudePrompt } = require('./src/chat-completion');
+const { getOverrideHeaders, setAdditionalHeaders } = require('./src/additional-headers');
 
 // Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
 // https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
@@ -119,70 +120,6 @@ const listen = getConfigValue('listen', false);
 const API_OPENAI = 'https://api.openai.com/v1';
 const API_CLAUDE = 'https://api.anthropic.com/v1';
 
-function getMancerHeaders() {
-    const apiKey = readSecret(SECRET_KEYS.MANCER);
-
-    return apiKey ? ({
-        'X-API-KEY': apiKey,
-        'Authorization': `Bearer ${apiKey}`,
-    }) : {};
-}
-
-function getAphroditeHeaders() {
-    const apiKey = readSecret(SECRET_KEYS.APHRODITE);
-
-    return apiKey ? ({
-        'X-API-KEY': apiKey,
-        'Authorization': `Bearer ${apiKey}`,
-    }) : {};
-}
-
-function getTabbyHeaders() {
-    const apiKey = readSecret(SECRET_KEYS.TABBY);
-
-    return apiKey ? ({
-        'x-api-key': apiKey,
-        'Authorization': `Bearer ${apiKey}`,
-    }) : {};
-}
-
-function getOverrideHeaders(urlHost) {
-    const requestOverrides = getConfigValue('requestOverrides', []);
-    const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
-    if (overrideHeaders && urlHost) {
-        return overrideHeaders;
-    } else {
-        return {};
-    }
-}
-
-/**
- * Sets additional headers for the request.
- * @param {object} request Original request body
- * @param {object} args New request arguments
- * @param {string|null} server API server for new request
- */
-function setAdditionalHeaders(request, args, server) {
-    let headers;
-
-    switch (request.body.api_type) {
-        case TEXTGEN_TYPES.MANCER:
-            headers = getMancerHeaders();
-            break;
-        case TEXTGEN_TYPES.APHRODITE:
-            headers = getAphroditeHeaders();
-            break;
-        case TEXTGEN_TYPES.TABBY:
-            headers = getTabbyHeaders();
-            break;
-        default:
-            headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
-            break;
-    }
-
-    Object.assign(args.headers, headers);
-}
-
 const SETTINGS_FILE = './public/settings.json';
 const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, TEXTGEN_TYPES, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants');
@@ -1774,93 +1711,6 @@ async function sendAI21Request(request, response) {
     }
 }
 
-app.post('/tokenize_via_api', jsonParser, async function (request, response) {
-    if (!request.body) {
-        return response.sendStatus(400);
-    }
-
-    const text = String(request.body.text) || '';
-    const api = String(request.body.main_api);
-    const baseUrl = String(request.body.url);
-    const legacyApi = Boolean(request.body.legacy_api);
-
-    try {
-        if (api == 'textgenerationwebui') {
-            const args = {
-                method: 'POST',
-                headers: { 'Content-Type': 'application/json' },
-            };
-
-            setAdditionalHeaders(request, args, null);
-
-            // Convert to string + remove trailing slash + /v1 suffix
-            let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, '');
-
-            if (legacyApi) {
-                url += '/v1/token-count';
-                args.body = JSON.stringify({ 'prompt': text });
-            } else {
-                switch (request.body.api_type) {
-                    case TEXTGEN_TYPES.TABBY:
-                        url += '/v1/token/encode';
-                        args.body = JSON.stringify({ 'text': text });
-                        break;
-                    case TEXTGEN_TYPES.KOBOLDCPP:
-                        url += '/api/extra/tokencount';
-                        args.body = JSON.stringify({ 'prompt': text });
-                        break;
-                    default:
-                        url += '/v1/internal/encode';
-                        args.body = JSON.stringify({ 'text': text });
-                        break;
-                }
-            }
-
-            const result = await fetch(url, args);
-
-            if (!result.ok) {
-                console.log(`API returned error: ${result.status} ${result.statusText}`);
-                return response.send({ error: true });
-            }
-
-            const data = await result.json();
-            const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
-            const ids = legacyApi ? [] : (data?.tokens ?? []);
-
-            return response.send({ count, ids });
-        }
-        else if (api == 'kobold') {
-            const args = {
-                method: 'POST',
-                body: JSON.stringify({ 'prompt': text }),
-                headers: { 'Content-Type': 'application/json' },
-            };
-
-            let url = String(baseUrl).replace(/\/$/, '');
-            url += '/extra/tokencount';
-
-            const result = await fetch(url, args);
-
-            if (!result.ok) {
-                console.log(`API returned error: ${result.status} ${result.statusText}`);
-                return response.send({ error: true });
-            }
-
-            const data = await result.json();
-            const count = data['value'];
-            return response.send({ count: count, ids: [] });
-        }
-        else {
-            console.log('Unknown API', api);
-            return response.send({ error: true });
-        }
-    } catch (error) {
-        console.log(error);
-        return response.send({ error: true });
-    }
-});
-
 /**
  * Redirect a deprecated API endpoint URL to its replacement. Because fetch, form submissions, and $.ajax follow
  * redirects, this is transparent to client-side code.

src/additional-headers.js (new file)

@@ -0,0 +1,72 @@
+const { TEXTGEN_TYPES } = require('./constants');
+const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
+const { getConfigValue } = require('./util');
+
+function getMancerHeaders() {
+    const apiKey = readSecret(SECRET_KEYS.MANCER);
+
+    return apiKey ? ({
+        'X-API-KEY': apiKey,
+        'Authorization': `Bearer ${apiKey}`,
+    }) : {};
+}
+
+function getAphroditeHeaders() {
+    const apiKey = readSecret(SECRET_KEYS.APHRODITE);
+
+    return apiKey ? ({
+        'X-API-KEY': apiKey,
+        'Authorization': `Bearer ${apiKey}`,
+    }) : {};
+}
+
+function getTabbyHeaders() {
+    const apiKey = readSecret(SECRET_KEYS.TABBY);
+
+    return apiKey ? ({
+        'x-api-key': apiKey,
+        'Authorization': `Bearer ${apiKey}`,
+    }) : {};
+}
+
+function getOverrideHeaders(urlHost) {
+    const requestOverrides = getConfigValue('requestOverrides', []);
+    const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
+    if (overrideHeaders && urlHost) {
+        return overrideHeaders;
+    } else {
+        return {};
+    }
+}
+
+/**
+ * Sets additional headers for the request.
+ * @param {object} request Original request body
+ * @param {object} args New request arguments
+ * @param {string|null} server API server for new request
+ */
+function setAdditionalHeaders(request, args, server) {
+    let headers;
+
+    switch (request.body.api_type) {
+        case TEXTGEN_TYPES.MANCER:
+            headers = getMancerHeaders();
+            break;
+        case TEXTGEN_TYPES.APHRODITE:
+            headers = getAphroditeHeaders();
+            break;
+        case TEXTGEN_TYPES.TABBY:
+            headers = getTabbyHeaders();
+            break;
+        default:
+            headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
+            break;
+    }
+
+    Object.assign(args.headers, headers);
+}
+
+module.exports = {
+    getOverrideHeaders,
+    setAdditionalHeaders,
+};
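A brief usage sketch of the extracted module (mirroring how server.js and the tokenizer endpoints call it; the request is an Express request whose body carries api_type, and the server URL is hypothetical):

    const { setAdditionalHeaders } = require('./src/additional-headers');

    // args.headers must already exist: the helper merges provider auth headers
    // (Mancer/Aphrodite/Tabby) or config-driven overrides into it in place.
    const args = { method: 'POST', headers: { 'Content-Type': 'application/json' } };
    setAdditionalHeaders(request, args, 'http://127.0.0.1:5000');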

src/endpoints/tokenizers.js

@@ -6,7 +6,9 @@ const tiktoken = require('@dqbd/tiktoken');
 const { Tokenizer } = require('@agnai/web-tokenizers');
 const { convertClaudePrompt } = require('../chat-completion');
 const { readSecret, SECRET_KEYS } = require('./secrets');
+const { TEXTGEN_TYPES } = require('../constants');
 const { jsonParser } = require('../express-common');
+const { setAdditionalHeaders } = require('../additional-headers');
 
 /**
  * @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
@@ -534,6 +536,96 @@ router.post('/openai/count', jsonParser, async function (req, res) {
     }
 });
 
+router.post('/remote/kobold/count', jsonParser, async function (request, response) {
+    if (!request.body) {
+        return response.sendStatus(400);
+    }
+
+    const text = String(request.body.text) || '';
+    const baseUrl = String(request.body.url);
+
+    try {
+        const args = {
+            method: 'POST',
+            body: JSON.stringify({ 'prompt': text }),
+            headers: { 'Content-Type': 'application/json' },
+        };
+
+        let url = String(baseUrl).replace(/\/$/, '');
+        url += '/extra/tokencount';
+
+        const result = await fetch(url, args);
+
+        if (!result.ok) {
+            console.log(`API returned error: ${result.status} ${result.statusText}`);
+            return response.send({ error: true });
+        }
+
+        const data = await result.json();
+        const count = data['value'];
+        return response.send({ count, ids: [] });
+    } catch (error) {
+        console.log(error);
+        return response.send({ error: true });
+    }
+});
+
+router.post('/remote/textgenerationwebui/encode', jsonParser, async function (request, response) {
+    if (!request.body) {
+        return response.sendStatus(400);
+    }
+
+    const text = String(request.body.text) || '';
+    const baseUrl = String(request.body.url);
+    const legacyApi = Boolean(request.body.legacy_api);
+
+    try {
+        const args = {
+            method: 'POST',
+            headers: { 'Content-Type': 'application/json' },
+        };
+
+        setAdditionalHeaders(request, args, null);
+
+        // Convert to string + remove trailing slash + /v1 suffix
+        let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, '');
+
+        if (legacyApi) {
+            url += '/v1/token-count';
+            args.body = JSON.stringify({ 'prompt': text });
+        } else {
+            switch (request.body.api_type) {
+                case TEXTGEN_TYPES.TABBY:
+                    url += '/v1/token/encode';
+                    args.body = JSON.stringify({ 'text': text });
+                    break;
+                case TEXTGEN_TYPES.KOBOLDCPP:
+                    url += '/api/extra/tokencount';
+                    args.body = JSON.stringify({ 'prompt': text });
+                    break;
+                default:
+                    url += '/v1/internal/encode';
+                    args.body = JSON.stringify({ 'text': text });
+                    break;
+            }
+        }
+
+        const result = await fetch(url, args);
+
+        if (!result.ok) {
+            console.log(`API returned error: ${result.status} ${result.statusText}`);
+            return response.send({ error: true });
+        }
+
+        const data = await result.json();
+        const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
+        const ids = legacyApi ? [] : (data?.tokens ?? []);
+        return response.send({ count, ids });
+    } catch (error) {
+        console.log(error);
+        return response.send({ error: true });
+    }
+});
+
 module.exports = {
     TEXT_COMPLETION_MODELS,
     getTokenizerModel,
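Finally, an illustrative client-side call against one of the relocated routes (the backend URL is hypothetical; the request and response shapes follow the handler above):

    // Count tokens through a KoboldAI Classic backend.
    const res = await fetch('/api/tokenizers/remote/kobold/count', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ text: 'Hello there!', url: 'http://127.0.0.1:5001' }),
    });
    const { count } = await res.json(); // -> { count: <number>, ids: [] }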