Merge pull request #1503 from valadaptive/tokenizers-cleanup
Tokenizers cleanup
This commit is contained in:
commit
ae01e7419f
|
@ -874,7 +874,7 @@ async function getStatusKobold() {
|
|||
|
||||
const url = '/getstatus';
|
||||
|
||||
let endpoint = getAPIServerUrl();
|
||||
let endpoint = api_server;
|
||||
|
||||
if (!endpoint) {
|
||||
console.warn('No endpoint for status check');
|
||||
|
@ -922,7 +922,9 @@ async function getStatusKobold() {
|
|||
async function getStatusTextgen() {
|
||||
const url = '/api/textgenerationwebui/status';
|
||||
|
||||
let endpoint = getAPIServerUrl();
|
||||
let endpoint = textgen_settings.type === MANCER ?
|
||||
MANCER_SERVER :
|
||||
api_server_textgenerationwebui;
|
||||
|
||||
if (!endpoint) {
|
||||
console.warn('No endpoint for status check');
|
||||
|
@ -1002,23 +1004,6 @@ export function resultCheckStatus() {
|
|||
stopStatusLoading();
|
||||
}
|
||||
|
||||
// TODO(valadaptive): remove the usage of this function in the tokenizers code, then remove the function entirely
|
||||
export function getAPIServerUrl() {
|
||||
if (main_api == 'textgenerationwebui') {
|
||||
if (textgen_settings.type === MANCER) {
|
||||
return MANCER_SERVER;
|
||||
}
|
||||
|
||||
return api_server_textgenerationwebui;
|
||||
}
|
||||
|
||||
if (main_api == 'kobold') {
|
||||
return api_server;
|
||||
}
|
||||
|
||||
return '';
|
||||
}
|
||||
|
||||
export async function selectCharacterById(id) {
|
||||
if (characters[id] == undefined) {
|
||||
return;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { characters, getAPIServerUrl, main_api, nai_settings, online_status, this_chid } from '../script.js';
|
||||
import { characters, main_api, api_server, api_server_textgenerationwebui, nai_settings, online_status, this_chid } from '../script.js';
|
||||
import { power_user, registerDebugFunction } from './power-user.js';
|
||||
import { chat_completion_sources, model_list, oai_settings } from './openai.js';
|
||||
import { groups, selected_group } from './group-chats.js';
|
||||
|
@ -18,9 +18,11 @@ export const tokenizers = {
|
|||
LLAMA: 3,
|
||||
NERD: 4,
|
||||
NERD2: 5,
|
||||
API: 6,
|
||||
API_CURRENT: 6,
|
||||
MISTRAL: 7,
|
||||
YI: 8,
|
||||
API_TEXTGENERATIONWEBUI: 9,
|
||||
API_KOBOLD: 10,
|
||||
BEST_MATCH: 99,
|
||||
};
|
||||
|
||||
|
@ -33,6 +35,51 @@ export const SENTENCEPIECE_TOKENIZERS = [
|
|||
//tokenizers.NERD2,
|
||||
];
|
||||
|
||||
const TOKENIZER_URLS = {
|
||||
[tokenizers.GPT2]: {
|
||||
encode: '/api/tokenizers/gpt2/encode',
|
||||
decode: '/api/tokenizers/gpt2/decode',
|
||||
count: '/api/tokenizers/gpt2/encode',
|
||||
},
|
||||
[tokenizers.OPENAI]: {
|
||||
encode: '/api/tokenizers/openai/encode',
|
||||
decode: '/api/tokenizers/openai/decode',
|
||||
count: '/api/tokenizers/openai/encode',
|
||||
},
|
||||
[tokenizers.LLAMA]: {
|
||||
encode: '/api/tokenizers/llama/encode',
|
||||
decode: '/api/tokenizers/llama/decode',
|
||||
count: '/api/tokenizers/llama/encode',
|
||||
},
|
||||
[tokenizers.NERD]: {
|
||||
encode: '/api/tokenizers/nerdstash/encode',
|
||||
decode: '/api/tokenizers/nerdstash/decode',
|
||||
count: '/api/tokenizers/nerdstash/encode',
|
||||
},
|
||||
[tokenizers.NERD2]: {
|
||||
encode: '/api/tokenizers/nerdstash_v2/encode',
|
||||
decode: '/api/tokenizers/nerdstash_v2/decode',
|
||||
count: '/api/tokenizers/nerdstash_v2/encode',
|
||||
},
|
||||
[tokenizers.API_KOBOLD]: {
|
||||
count: '/api/tokenizers/remote/kobold/count',
|
||||
},
|
||||
[tokenizers.MISTRAL]: {
|
||||
encode: '/api/tokenizers/mistral/encode',
|
||||
decode: '/api/tokenizers/mistral/decode',
|
||||
count: '/api/tokenizers/mistral/encode',
|
||||
},
|
||||
[tokenizers.YI]: {
|
||||
encode: '/api/tokenizers/yi/encode',
|
||||
decode: '/api/tokenizers/yi/decode',
|
||||
count: '/api/tokenizers/yi/encode',
|
||||
},
|
||||
[tokenizers.API_TEXTGENERATIONWEBUI]: {
|
||||
encode: '/api/tokenizers/remote/textgenerationwebui/encode',
|
||||
count: '/api/tokenizers/remote/textgenerationwebui/encode',
|
||||
},
|
||||
};
|
||||
|
||||
const objectStore = new localforage.createInstance({ name: 'SillyTavern_ChatCompletions' });
|
||||
|
||||
let tokenCache = {};
|
||||
|
@ -92,7 +139,18 @@ export function getFriendlyTokenizerName(forApi) {
|
|||
|
||||
if (forApi !== 'openai' && tokenizerId === tokenizers.BEST_MATCH) {
|
||||
tokenizerId = getTokenizerBestMatch(forApi);
|
||||
|
||||
switch (tokenizerId) {
|
||||
case tokenizers.API_KOBOLD:
|
||||
tokenizerName = 'API (KoboldAI Classic)';
|
||||
break;
|
||||
case tokenizers.API_TEXTGENERATIONWEBUI:
|
||||
tokenizerName = 'API (Text Completion)';
|
||||
break;
|
||||
default:
|
||||
tokenizerName = $(`#tokenizer option[value="${tokenizerId}"]`).text();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tokenizerName = forApi == 'openai'
|
||||
|
@ -135,11 +193,11 @@ export function getTokenizerBestMatch(forApi) {
|
|||
|
||||
if (!hasTokenizerError && isConnected) {
|
||||
if (forApi === 'kobold' && kai_flags.can_use_tokenization) {
|
||||
return tokenizers.API;
|
||||
return tokenizers.API_KOBOLD;
|
||||
}
|
||||
|
||||
if (forApi === 'textgenerationwebui' && isTokenizerSupported) {
|
||||
return tokenizers.API;
|
||||
return tokenizers.API_TEXTGENERATIONWEBUI;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -149,34 +207,42 @@ export function getTokenizerBestMatch(forApi) {
|
|||
return tokenizers.NONE;
|
||||
}
|
||||
|
||||
// Get the current remote tokenizer API based on the current text generation API.
|
||||
function currentRemoteTokenizerAPI() {
|
||||
switch (main_api) {
|
||||
case 'kobold':
|
||||
return tokenizers.API_KOBOLD;
|
||||
case 'textgenerationwebui':
|
||||
return tokenizers.API_TEXTGENERATIONWEBUI;
|
||||
default:
|
||||
return tokenizers.NONE;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls the underlying tokenizer model to the token count for a string.
|
||||
* @param {number} type Tokenizer type.
|
||||
* @param {string} str String to tokenize.
|
||||
* @param {number} padding Number of padding tokens.
|
||||
* @returns {number} Token count.
|
||||
*/
|
||||
function callTokenizer(type, str, padding) {
|
||||
function callTokenizer(type, str) {
|
||||
if (type === tokenizers.NONE) return guesstimate(str);
|
||||
|
||||
switch (type) {
|
||||
case tokenizers.NONE:
|
||||
return guesstimate(str) + padding;
|
||||
case tokenizers.GPT2:
|
||||
return countTokensRemote('/api/tokenizers/gpt2/encode', str, padding);
|
||||
case tokenizers.LLAMA:
|
||||
return countTokensRemote('/api/tokenizers/llama/encode', str, padding);
|
||||
case tokenizers.NERD:
|
||||
return countTokensRemote('/api/tokenizers/nerdstash/encode', str, padding);
|
||||
case tokenizers.NERD2:
|
||||
return countTokensRemote('/api/tokenizers/nerdstash_v2/encode', str, padding);
|
||||
case tokenizers.MISTRAL:
|
||||
return countTokensRemote('/api/tokenizers/mistral/encode', str, padding);
|
||||
case tokenizers.YI:
|
||||
return countTokensRemote('/api/tokenizers/yi/encode', str, padding);
|
||||
case tokenizers.API:
|
||||
return countTokensRemote('/tokenize_via_api', str, padding);
|
||||
default:
|
||||
case tokenizers.API_CURRENT:
|
||||
return callTokenizer(currentRemoteTokenizerAPI(), str);
|
||||
case tokenizers.API_KOBOLD:
|
||||
return countTokensFromKoboldAPI(str);
|
||||
case tokenizers.API_TEXTGENERATIONWEBUI:
|
||||
return countTokensFromTextgenAPI(str);
|
||||
default: {
|
||||
const endpointUrl = TOKENIZER_URLS[type]?.count;
|
||||
if (!endpointUrl) {
|
||||
console.warn('Unknown tokenizer type', type);
|
||||
return callTokenizer(tokenizers.NONE, str, padding);
|
||||
return apiFailureTokenCount(str);
|
||||
}
|
||||
return countTokensFromServer(endpointUrl, str);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -219,7 +285,7 @@ export function getTokenCount(str, padding = undefined) {
|
|||
return cacheObject[cacheKey];
|
||||
}
|
||||
|
||||
const result = callTokenizer(tokenizerType, str, padding);
|
||||
const result = callTokenizer(tokenizerType, str) + padding;
|
||||
|
||||
if (isNaN(result)) {
|
||||
console.warn('Token count calculation returned NaN');
|
||||
|
@ -391,40 +457,103 @@ function getTokenCacheObject() {
|
|||
return tokenCache[String(chatId)];
|
||||
}
|
||||
|
||||
function getRemoteTokenizationParams(str) {
|
||||
return {
|
||||
text: str,
|
||||
main_api,
|
||||
api_type: textgen_settings.type,
|
||||
url: getAPIServerUrl(),
|
||||
legacy_api: main_api === 'textgenerationwebui' &&
|
||||
textgen_settings.legacy_api &&
|
||||
textgen_settings.type !== MANCER,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Counts token using the remote server API.
|
||||
* Count tokens using the server API.
|
||||
* @param {string} endpoint API endpoint.
|
||||
* @param {string} str String to tokenize.
|
||||
* @param {number} padding Number of padding tokens.
|
||||
* @returns {number} Token count with padding.
|
||||
* @returns {number} Token count.
|
||||
*/
|
||||
function countTokensRemote(endpoint, str, padding) {
|
||||
function countTokensFromServer(endpoint, str) {
|
||||
let tokenCount = 0;
|
||||
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
type: 'POST',
|
||||
url: endpoint,
|
||||
data: JSON.stringify(getRemoteTokenizationParams(str)),
|
||||
data: JSON.stringify({ text: str }),
|
||||
dataType: 'json',
|
||||
contentType: 'application/json',
|
||||
success: function (data) {
|
||||
if (typeof data.count === 'number') {
|
||||
tokenCount = data.count;
|
||||
} else {
|
||||
tokenCount = guesstimate(str);
|
||||
tokenCount = apiFailureTokenCount(str);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
/**
|
||||
* Count tokens using the AI provider's API.
|
||||
* @param {string} str String to tokenize.
|
||||
* @returns {number} Token count.
|
||||
*/
|
||||
function countTokensFromKoboldAPI(str) {
|
||||
let tokenCount = 0;
|
||||
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
type: 'POST',
|
||||
url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count,
|
||||
data: JSON.stringify({
|
||||
text: str,
|
||||
url: api_server,
|
||||
}),
|
||||
dataType: 'json',
|
||||
contentType: 'application/json',
|
||||
success: function (data) {
|
||||
if (typeof data.count === 'number') {
|
||||
tokenCount = data.count;
|
||||
} else {
|
||||
tokenCount = apiFailureTokenCount(str);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
function getTextgenAPITokenizationParams(str) {
|
||||
return {
|
||||
text: str,
|
||||
api_type: textgen_settings.type,
|
||||
url: api_server_textgenerationwebui,
|
||||
legacy_api:
|
||||
textgen_settings.legacy_api &&
|
||||
textgen_settings.type !== MANCER,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Count tokens using the AI provider's API.
|
||||
* @param {string} str String to tokenize.
|
||||
* @returns {number} Token count.
|
||||
*/
|
||||
function countTokensFromTextgenAPI(str) {
|
||||
let tokenCount = 0;
|
||||
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
type: 'POST',
|
||||
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count,
|
||||
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
|
||||
dataType: 'json',
|
||||
contentType: 'application/json',
|
||||
success: function (data) {
|
||||
if (typeof data.count === 'number') {
|
||||
tokenCount = data.count;
|
||||
} else {
|
||||
tokenCount = apiFailureTokenCount(str);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
function apiFailureTokenCount(str) {
|
||||
console.error('Error counting tokens');
|
||||
|
||||
if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
|
||||
|
@ -436,31 +565,23 @@ function countTokensRemote(endpoint, str, padding) {
|
|||
|
||||
sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
return tokenCount + padding;
|
||||
return guesstimate(str);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls the underlying tokenizer model to encode a string to tokens.
|
||||
* @param {string} endpoint API endpoint.
|
||||
* @param {string} str String to tokenize.
|
||||
* @param {string} model Tokenizer model.
|
||||
* @returns {number[]} Array of token ids.
|
||||
*/
|
||||
function getTextTokensRemote(endpoint, str, model = '') {
|
||||
if (model) {
|
||||
endpoint += `?model=${model}`;
|
||||
}
|
||||
|
||||
function getTextTokensFromServer(endpoint, str) {
|
||||
let ids = [];
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
type: 'POST',
|
||||
url: endpoint,
|
||||
data: JSON.stringify(getRemoteTokenizationParams(str)),
|
||||
data: JSON.stringify({ text: str }),
|
||||
dataType: 'json',
|
||||
contentType: 'application/json',
|
||||
success: function (data) {
|
||||
|
@ -475,16 +596,33 @@ function getTextTokensRemote(endpoint, str, model = '') {
|
|||
return ids;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls the AI provider's tokenize API to encode a string to tokens.
|
||||
* @param {string} str String to tokenize.
|
||||
* @returns {number[]} Array of token ids.
|
||||
*/
|
||||
function getTextTokensFromTextgenAPI(str) {
|
||||
let ids = [];
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
type: 'POST',
|
||||
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode,
|
||||
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
|
||||
dataType: 'json',
|
||||
contentType: 'application/json',
|
||||
success: function (data) {
|
||||
ids = data.ids;
|
||||
},
|
||||
});
|
||||
return ids;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls the underlying tokenizer model to decode token ids to text.
|
||||
* @param {string} endpoint API endpoint.
|
||||
* @param {number[]} ids Array of token ids
|
||||
*/
|
||||
function decodeTextTokensRemote(endpoint, ids, model = '') {
|
||||
if (model) {
|
||||
endpoint += `?model=${model}`;
|
||||
}
|
||||
|
||||
function decodeTextTokensFromServer(endpoint, ids) {
|
||||
let text = '';
|
||||
jQuery.ajax({
|
||||
async: false,
|
||||
|
@ -501,64 +639,62 @@ function decodeTextTokensRemote(endpoint, ids, model = '') {
|
|||
}
|
||||
|
||||
/**
|
||||
* Encodes a string to tokens using the remote server API.
|
||||
* Encodes a string to tokens using the server API.
|
||||
* @param {number} tokenizerType Tokenizer type.
|
||||
* @param {string} str String to tokenize.
|
||||
* @returns {number[]} Array of token ids.
|
||||
*/
|
||||
export function getTextTokens(tokenizerType, str) {
|
||||
switch (tokenizerType) {
|
||||
case tokenizers.GPT2:
|
||||
return getTextTokensRemote('/api/tokenizers/gpt2/encode', str);
|
||||
case tokenizers.LLAMA:
|
||||
return getTextTokensRemote('/api/tokenizers/llama/encode', str);
|
||||
case tokenizers.NERD:
|
||||
return getTextTokensRemote('/api/tokenizers/nerdstash/encode', str);
|
||||
case tokenizers.NERD2:
|
||||
return getTextTokensRemote('/api/tokenizers/nerdstash_v2/encode', str);
|
||||
case tokenizers.MISTRAL:
|
||||
return getTextTokensRemote('/api/tokenizers/mistral/encode', str);
|
||||
case tokenizers.YI:
|
||||
return getTextTokensRemote('/api/tokenizers/yi/encode', str);
|
||||
case tokenizers.OPENAI: {
|
||||
const model = getTokenizerModel();
|
||||
return getTextTokensRemote('/api/tokenizers/openai/encode', str, model);
|
||||
}
|
||||
case tokenizers.API:
|
||||
return getTextTokensRemote('/tokenize_via_api', str);
|
||||
default:
|
||||
console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType);
|
||||
case tokenizers.API_CURRENT:
|
||||
return getTextTokens(currentRemoteTokenizerAPI(), str);
|
||||
case tokenizers.API_TEXTGENERATIONWEBUI:
|
||||
return getTextTokensFromTextgenAPI(str);
|
||||
default: {
|
||||
const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
|
||||
if (!tokenizerEndpoints) {
|
||||
apiFailureTokenCount(str);
|
||||
console.warn('Unknown tokenizer type', tokenizerType);
|
||||
return [];
|
||||
}
|
||||
let endpointUrl = tokenizerEndpoints.encode;
|
||||
if (!endpointUrl) {
|
||||
apiFailureTokenCount(str);
|
||||
console.warn('This tokenizer type does not support encoding', tokenizerType);
|
||||
return [];
|
||||
}
|
||||
if (tokenizerType === tokenizers.OPENAI) {
|
||||
endpointUrl += `?model=${getTokenizerModel()}`;
|
||||
}
|
||||
return getTextTokensFromServer(endpointUrl, str);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Decodes token ids to text using the remote server API.
|
||||
* Decodes token ids to text using the server API.
|
||||
* @param {number} tokenizerType Tokenizer type.
|
||||
* @param {number[]} ids Array of token ids
|
||||
*/
|
||||
export function decodeTextTokens(tokenizerType, ids) {
|
||||
switch (tokenizerType) {
|
||||
case tokenizers.GPT2:
|
||||
return decodeTextTokensRemote('/api/tokenizers/gpt2/decode', ids);
|
||||
case tokenizers.LLAMA:
|
||||
return decodeTextTokensRemote('/api/tokenizers/llama/decode', ids);
|
||||
case tokenizers.NERD:
|
||||
return decodeTextTokensRemote('/api/tokenizers/nerdstash/decode', ids);
|
||||
case tokenizers.NERD2:
|
||||
return decodeTextTokensRemote('/api/tokenizers/nerdstash_v2/decode', ids);
|
||||
case tokenizers.MISTRAL:
|
||||
return decodeTextTokensRemote('/api/tokenizers/mistral/decode', ids);
|
||||
case tokenizers.YI:
|
||||
return decodeTextTokensRemote('/api/tokenizers/yi/decode', ids);
|
||||
case tokenizers.OPENAI: {
|
||||
const model = getTokenizerModel();
|
||||
return decodeTextTokensRemote('/api/tokenizers/openai/decode', ids, model);
|
||||
// Currently, neither remote API can decode, but this may change in the future. Put this guard here to be safe
|
||||
if (tokenizerType === tokenizers.API_CURRENT) {
|
||||
return decodeTextTokens(tokenizers.NONE, ids);
|
||||
}
|
||||
default:
|
||||
console.warn('Calling decodeTextTokens with unsupported tokenizer type', tokenizerType);
|
||||
return '';
|
||||
const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
|
||||
if (!tokenizerEndpoints) {
|
||||
console.warn('Unknown tokenizer type', tokenizerType);
|
||||
return [];
|
||||
}
|
||||
let endpointUrl = tokenizerEndpoints.decode;
|
||||
if (!endpointUrl) {
|
||||
console.warn('This tokenizer type does not support decoding', tokenizerType);
|
||||
return [];
|
||||
}
|
||||
if (tokenizerType === tokenizers.OPENAI) {
|
||||
endpointUrl += `?model=${getTokenizerModel()}`;
|
||||
}
|
||||
return decodeTextTokensFromServer(endpointUrl, ids);
|
||||
}
|
||||
|
||||
export async function initTokenizers() {
|
||||
|
|
152
server.js
152
server.js
|
@ -49,6 +49,7 @@ const { delay, getVersion, getConfigValue, color, uuidv4, tryParse, clientRelati
|
|||
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
|
||||
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
|
||||
const { convertClaudePrompt } = require('./src/chat-completion');
|
||||
const { getOverrideHeaders, setAdditionalHeaders } = require('./src/additional-headers');
|
||||
|
||||
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
|
||||
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
|
||||
|
@ -119,70 +120,6 @@ const listen = getConfigValue('listen', false);
|
|||
const API_OPENAI = 'https://api.openai.com/v1';
|
||||
const API_CLAUDE = 'https://api.anthropic.com/v1';
|
||||
|
||||
function getMancerHeaders() {
|
||||
const apiKey = readSecret(SECRET_KEYS.MANCER);
|
||||
|
||||
return apiKey ? ({
|
||||
'X-API-KEY': apiKey,
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
}) : {};
|
||||
}
|
||||
|
||||
function getAphroditeHeaders() {
|
||||
const apiKey = readSecret(SECRET_KEYS.APHRODITE);
|
||||
|
||||
return apiKey ? ({
|
||||
'X-API-KEY': apiKey,
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
}) : {};
|
||||
}
|
||||
|
||||
function getTabbyHeaders() {
|
||||
const apiKey = readSecret(SECRET_KEYS.TABBY);
|
||||
|
||||
return apiKey ? ({
|
||||
'x-api-key': apiKey,
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
}) : {};
|
||||
}
|
||||
|
||||
function getOverrideHeaders(urlHost) {
|
||||
const requestOverrides = getConfigValue('requestOverrides', []);
|
||||
const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
|
||||
if (overrideHeaders && urlHost) {
|
||||
return overrideHeaders;
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets additional headers for the request.
|
||||
* @param {object} request Original request body
|
||||
* @param {object} args New request arguments
|
||||
* @param {string|null} server API server for new request
|
||||
*/
|
||||
function setAdditionalHeaders(request, args, server) {
|
||||
let headers;
|
||||
|
||||
switch (request.body.api_type) {
|
||||
case TEXTGEN_TYPES.MANCER:
|
||||
headers = getMancerHeaders();
|
||||
break;
|
||||
case TEXTGEN_TYPES.APHRODITE:
|
||||
headers = getAphroditeHeaders();
|
||||
break;
|
||||
case TEXTGEN_TYPES.TABBY:
|
||||
headers = getTabbyHeaders();
|
||||
break;
|
||||
default:
|
||||
headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
|
||||
break;
|
||||
}
|
||||
|
||||
Object.assign(args.headers, headers);
|
||||
}
|
||||
|
||||
const SETTINGS_FILE = './public/settings.json';
|
||||
const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, TEXTGEN_TYPES, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants');
|
||||
|
||||
|
@ -1774,93 +1711,6 @@ async function sendAI21Request(request, response) {
|
|||
|
||||
}
|
||||
|
||||
app.post('/tokenize_via_api', jsonParser, async function (request, response) {
|
||||
if (!request.body) {
|
||||
return response.sendStatus(400);
|
||||
}
|
||||
const text = String(request.body.text) || '';
|
||||
const api = String(request.body.main_api);
|
||||
const baseUrl = String(request.body.url);
|
||||
const legacyApi = Boolean(request.body.legacy_api);
|
||||
|
||||
try {
|
||||
if (api == 'textgenerationwebui') {
|
||||
const args = {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
};
|
||||
|
||||
setAdditionalHeaders(request, args, null);
|
||||
|
||||
// Convert to string + remove trailing slash + /v1 suffix
|
||||
let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, '');
|
||||
|
||||
if (legacyApi) {
|
||||
url += '/v1/token-count';
|
||||
args.body = JSON.stringify({ 'prompt': text });
|
||||
} else {
|
||||
switch (request.body.api_type) {
|
||||
case TEXTGEN_TYPES.TABBY:
|
||||
url += '/v1/token/encode';
|
||||
args.body = JSON.stringify({ 'text': text });
|
||||
break;
|
||||
case TEXTGEN_TYPES.KOBOLDCPP:
|
||||
url += '/api/extra/tokencount';
|
||||
args.body = JSON.stringify({ 'prompt': text });
|
||||
break;
|
||||
default:
|
||||
url += '/v1/internal/encode';
|
||||
args.body = JSON.stringify({ 'text': text });
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const result = await fetch(url, args);
|
||||
|
||||
if (!result.ok) {
|
||||
console.log(`API returned error: ${result.status} ${result.statusText}`);
|
||||
return response.send({ error: true });
|
||||
}
|
||||
|
||||
const data = await result.json();
|
||||
const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
|
||||
const ids = legacyApi ? [] : (data?.tokens ?? []);
|
||||
|
||||
return response.send({ count, ids });
|
||||
}
|
||||
|
||||
else if (api == 'kobold') {
|
||||
const args = {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ 'prompt': text }),
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
};
|
||||
|
||||
let url = String(baseUrl).replace(/\/$/, '');
|
||||
url += '/extra/tokencount';
|
||||
|
||||
const result = await fetch(url, args);
|
||||
|
||||
if (!result.ok) {
|
||||
console.log(`API returned error: ${result.status} ${result.statusText}`);
|
||||
return response.send({ error: true });
|
||||
}
|
||||
|
||||
const data = await result.json();
|
||||
const count = data['value'];
|
||||
return response.send({ count: count, ids: [] });
|
||||
}
|
||||
|
||||
else {
|
||||
console.log('Unknown API', api);
|
||||
return response.send({ error: true });
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
return response.send({ error: true });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Redirect a deprecated API endpoint URL to its replacement. Because fetch, form submissions, and $.ajax follow
|
||||
* redirects, this is transparent to client-side code.
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
const { TEXTGEN_TYPES } = require('./constants');
|
||||
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
const { getConfigValue } = require('./util');
|
||||
|
||||
function getMancerHeaders() {
|
||||
const apiKey = readSecret(SECRET_KEYS.MANCER);
|
||||
|
||||
return apiKey ? ({
|
||||
'X-API-KEY': apiKey,
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
}) : {};
|
||||
}
|
||||
|
||||
function getAphroditeHeaders() {
|
||||
const apiKey = readSecret(SECRET_KEYS.APHRODITE);
|
||||
|
||||
return apiKey ? ({
|
||||
'X-API-KEY': apiKey,
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
}) : {};
|
||||
}
|
||||
|
||||
function getTabbyHeaders() {
|
||||
const apiKey = readSecret(SECRET_KEYS.TABBY);
|
||||
|
||||
return apiKey ? ({
|
||||
'x-api-key': apiKey,
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
}) : {};
|
||||
}
|
||||
|
||||
function getOverrideHeaders(urlHost) {
|
||||
const requestOverrides = getConfigValue('requestOverrides', []);
|
||||
const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
|
||||
if (overrideHeaders && urlHost) {
|
||||
return overrideHeaders;
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets additional headers for the request.
|
||||
* @param {object} request Original request body
|
||||
* @param {object} args New request arguments
|
||||
* @param {string|null} server API server for new request
|
||||
*/
|
||||
function setAdditionalHeaders(request, args, server) {
|
||||
let headers;
|
||||
|
||||
switch (request.body.api_type) {
|
||||
case TEXTGEN_TYPES.MANCER:
|
||||
headers = getMancerHeaders();
|
||||
break;
|
||||
case TEXTGEN_TYPES.APHRODITE:
|
||||
headers = getAphroditeHeaders();
|
||||
break;
|
||||
case TEXTGEN_TYPES.TABBY:
|
||||
headers = getTabbyHeaders();
|
||||
break;
|
||||
default:
|
||||
headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
|
||||
break;
|
||||
}
|
||||
|
||||
Object.assign(args.headers, headers);
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getOverrideHeaders,
|
||||
setAdditionalHeaders,
|
||||
};
|
|
@ -6,7 +6,9 @@ const tiktoken = require('@dqbd/tiktoken');
|
|||
const { Tokenizer } = require('@agnai/web-tokenizers');
|
||||
const { convertClaudePrompt } = require('../chat-completion');
|
||||
const { readSecret, SECRET_KEYS } = require('./secrets');
|
||||
const { TEXTGEN_TYPES } = require('../constants');
|
||||
const { jsonParser } = require('../express-common');
|
||||
const { setAdditionalHeaders } = require('../additional-headers');
|
||||
|
||||
/**
|
||||
* @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
|
||||
|
@ -534,6 +536,96 @@ router.post('/openai/count', jsonParser, async function (req, res) {
|
|||
}
|
||||
});
|
||||
|
||||
router.post('/remote/kobold/count', jsonParser, async function (request, response) {
|
||||
if (!request.body) {
|
||||
return response.sendStatus(400);
|
||||
}
|
||||
const text = String(request.body.text) || '';
|
||||
const baseUrl = String(request.body.url);
|
||||
|
||||
try {
|
||||
const args = {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ 'prompt': text }),
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
};
|
||||
|
||||
let url = String(baseUrl).replace(/\/$/, '');
|
||||
url += '/extra/tokencount';
|
||||
|
||||
const result = await fetch(url, args);
|
||||
|
||||
if (!result.ok) {
|
||||
console.log(`API returned error: ${result.status} ${result.statusText}`);
|
||||
return response.send({ error: true });
|
||||
}
|
||||
|
||||
const data = await result.json();
|
||||
const count = data['value'];
|
||||
return response.send({ count, ids: [] });
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
return response.send({ error: true });
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/remote/textgenerationwebui/encode', jsonParser, async function (request, response) {
|
||||
if (!request.body) {
|
||||
return response.sendStatus(400);
|
||||
}
|
||||
const text = String(request.body.text) || '';
|
||||
const baseUrl = String(request.body.url);
|
||||
const legacyApi = Boolean(request.body.legacy_api);
|
||||
|
||||
try {
|
||||
const args = {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
};
|
||||
|
||||
setAdditionalHeaders(request, args, null);
|
||||
|
||||
// Convert to string + remove trailing slash + /v1 suffix
|
||||
let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, '');
|
||||
|
||||
if (legacyApi) {
|
||||
url += '/v1/token-count';
|
||||
args.body = JSON.stringify({ 'prompt': text });
|
||||
} else {
|
||||
switch (request.body.api_type) {
|
||||
case TEXTGEN_TYPES.TABBY:
|
||||
url += '/v1/token/encode';
|
||||
args.body = JSON.stringify({ 'text': text });
|
||||
break;
|
||||
case TEXTGEN_TYPES.KOBOLDCPP:
|
||||
url += '/api/extra/tokencount';
|
||||
args.body = JSON.stringify({ 'prompt': text });
|
||||
break;
|
||||
default:
|
||||
url += '/v1/internal/encode';
|
||||
args.body = JSON.stringify({ 'text': text });
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const result = await fetch(url, args);
|
||||
|
||||
if (!result.ok) {
|
||||
console.log(`API returned error: ${result.status} ${result.statusText}`);
|
||||
return response.send({ error: true });
|
||||
}
|
||||
|
||||
const data = await result.json();
|
||||
const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
|
||||
const ids = legacyApi ? [] : (data?.tokens ?? []);
|
||||
|
||||
return response.send({ count, ids });
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
return response.send({ error: true });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = {
|
||||
TEXT_COMPLETION_MODELS,
|
||||
getTokenizerModel,
|
||||
|
|
Loading…
Reference in New Issue