Merge branch 'staging' into qr-editor-tab-support

This commit is contained in:
Cohee 2023-12-10 16:36:28 +02:00
commit 5054de247b
8 changed files with 556 additions and 378 deletions

View File

@ -232,7 +232,6 @@ export {
isStreamingEnabled,
getThumbnailUrl,
getStoppingStrings,
getStatus,
reloadMarkdownProcessor,
getCurrentChatId,
chat,
@ -526,14 +525,17 @@ function getUrlSync(url, cache = true) {
}).responseText;
}
const templateCache = {};
const templateCache = new Map();
export function renderTemplate(templateId, templateData = {}, sanitize = true, localize = true, fullPath = false) {
try {
const pathToTemplate = fullPath ? templateId : `/scripts/templates/${templateId}.html`;
const templateContent = (pathToTemplate in templateCache) ? templateCache[pathToTemplate] : getUrlSync(pathToTemplate);
templateCache[pathToTemplate] = templateContent;
const template = Handlebars.compile(templateContent);
let template = templateCache.get(pathToTemplate);
if (!template) {
const templateContent = getUrlSync(pathToTemplate);
template = Handlebars.compile(templateContent);
templateCache.set(pathToTemplate, template);
}
let result = template(templateData);
if (sanitize) {
@ -857,7 +859,7 @@ export async function clearItemizedPrompts() {
}
}
async function getStatus() {
async function getStatusKobold() {
if (main_api == 'koboldhorde') {
try {
const hordeStatus = await checkHordeStatus();
@ -870,9 +872,9 @@ async function getStatus() {
return resultCheckStatus();
}
const url = main_api == 'textgenerationwebui' ? '/api/textgenerationwebui/status' : '/getstatus';
const url = '/getstatus';
let endpoint = getAPIServerUrl();
let endpoint = api_server;
if (!endpoint) {
console.warn('No endpoint for status check');
@ -886,18 +888,66 @@ async function getStatus() {
body: JSON.stringify({
main_api,
api_server: endpoint,
api_type: textgen_settings.type,
legacy_api: main_api == 'textgenerationwebui' ?
textgen_settings.legacy_api &&
textgen_settings.type !== MANCER :
false,
}),
signal: abortStatusCheck.signal,
});
const data = await response.json();
if (main_api == 'textgenerationwebui' && textgen_settings.type === MANCER) {
online_status = data?.result;
if (!online_status) {
online_status = 'no_connection';
}
// Determine instruct mode preset
autoSelectInstructPreset(online_status);
// determine if we can use stop sequence and streaming
setKoboldFlags(data.version, data.koboldVersion);
// We didn't get a 200 status code, but the endpoint has an explanation. Which means it DID connect, but I digress.
if (online_status === 'no_connection' && data.response) {
toastr.error(data.response, 'API Error', { timeOut: 5000, preventDuplicates: true });
}
} catch (err) {
console.error('Error getting status', err);
online_status = 'no_connection';
}
return resultCheckStatus();
}
async function getStatusTextgen() {
const url = '/api/textgenerationwebui/status';
let endpoint = textgen_settings.type === MANCER ?
MANCER_SERVER :
api_server_textgenerationwebui;
if (!endpoint) {
console.warn('No endpoint for status check');
return;
}
try {
const response = await fetch(url, {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
api_server: endpoint,
api_type: textgen_settings.type,
legacy_api:
textgen_settings.legacy_api &&
textgen_settings.type !== MANCER,
}),
signal: abortStatusCheck.signal,
});
const data = await response.json();
if (textgen_settings.type === MANCER) {
online_status = textgen_settings.mancer_model;
loadMancerModels(data?.data);
} else {
@ -911,11 +961,6 @@ async function getStatus() {
// Determine instruct mode preset
autoSelectInstructPreset(online_status);
// determine if we can use stop sequence and streaming
if (main_api === 'kobold' || main_api === 'koboldhorde') {
setKoboldFlags(data.version, data.koboldVersion);
}
// We didn't get a 200 status code, but the endpoint has an explanation. Which means it DID connect, but I digress.
if (online_status === 'no_connection' && data.response) {
toastr.error(data.response, 'API Error', { timeOut: 5000, preventDuplicates: true });
@ -928,6 +973,22 @@ async function getStatus() {
return resultCheckStatus();
}
/**
 * Checks NovelAI availability by loading the user's subscription data.
 * On success sets the global `online_status` to the value of getNovelTier()
 * (presumably the subscription tier name - confirm against that helper);
 * on any failure sets it to 'no_connection'. Always refreshes the
 * connection-status UI via resultCheckStatus().
 */
async function getStatusNovel() {
try {
const result = await loadNovelSubscriptionData();
if (!result) {
throw new Error('Could not load subscription data');
}
online_status = getNovelTier();
} catch {
// Deliberate best-effort: any failure is rendered as a plain
// disconnected state instead of surfacing the exception.
online_status = 'no_connection';
}
resultCheckStatus();
}
export function startStatusLoading() {
$('.api_loading').show();
$('.api_button').addClass('disabled');
@ -943,22 +1004,6 @@ export function resultCheckStatus() {
stopStatusLoading();
}
export function getAPIServerUrl() {
if (main_api == 'textgenerationwebui') {
if (textgen_settings.type === MANCER) {
return MANCER_SERVER;
}
return api_server_textgenerationwebui;
}
if (main_api == 'kobold') {
return api_server;
}
return '';
}
export async function selectCharacterById(id) {
if (characters[id] == undefined) {
return;
@ -1494,7 +1539,7 @@ function messageFormatting(mes, ch_name, isSystem, isUser) {
mes = mes.replace(new RegExp(`(^|\n)${ch_name}:`, 'g'), '$1');
}
mes = DOMPurify.sanitize(mes);
mes = DOMPurify.sanitize(mes, { FORBID_TAGS: ['style'] });
return mes;
}
@ -5314,7 +5359,7 @@ function changeMainAPI() {
}
if (main_api == 'koboldhorde') {
getStatus();
getStatusKobold();
getHordeModels();
}
@ -6031,22 +6076,6 @@ export async function displayPastChats() {
});
}
async function getStatusNovel() {
try {
const result = await loadNovelSubscriptionData();
if (!result) {
throw new Error('Could not load subscription data');
}
online_status = getNovelTier();
} catch {
online_status = 'no_connection';
}
resultCheckStatus();
}
function selectRightMenuWithAnimation(selectedMenuId) {
const displayModes = {
'rm_group_chats_block': 'flex',
@ -8268,7 +8297,7 @@ jQuery(async function () {
main_api = 'kobold';
saveSettingsDebounced();
getStatus();
getStatusKobold();
}
});
@ -8304,7 +8333,25 @@ jQuery(async function () {
startStatusLoading();
main_api = 'textgenerationwebui';
saveSettingsDebounced();
getStatus();
getStatusTextgen();
});
// NovelAI connect button: persist the entered API key (if any) as a secret,
// then verify the connection right away instead of waiting for the periodic
// status poll.
$('#api_button_novel').on('click', async function (e) {
e.stopPropagation();
const api_key_novel = String($('#api_key_novel').val()).trim();
if (api_key_novel.length) {
await writeSecret(SECRET_KEYS.NOVEL, api_key_novel);
}
if (!secret_state[SECRET_KEYS.NOVEL]) {
// No key stored from this or a previous session - nothing to check.
console.log('No secret key saved for NovelAI');
return;
}
startStatusLoading();
// Check near immediately rather than waiting for up to 90s
await getStatusNovel();
});
var button = $('#options_button');
@ -8993,24 +9040,6 @@ jQuery(async function () {
});
//Select chat
$('#api_button_novel').on('click', async function (e) {
e.stopPropagation();
const api_key_novel = String($('#api_key_novel').val()).trim();
if (api_key_novel.length) {
await writeSecret(SECRET_KEYS.NOVEL, api_key_novel);
}
if (!secret_state[SECRET_KEYS.NOVEL]) {
console.log('No secret key saved for NovelAI');
return;
}
startStatusLoading();
// Check near immediately rather than waiting for up to 90s
await getStatusNovel();
});
//**************************CHARACTER IMPORT EXPORT*************************//
$('#character_import_button').click(function () {
$('#character_import_file').click();

View File

@ -902,7 +902,7 @@ export function initRossMods() {
const chatBlock = $('#chat');
const originalScrollBottom = chatBlock[0].scrollHeight - (chatBlock.scrollTop() + chatBlock.outerHeight());
this.style.height = window.getComputedStyle(this).getPropertyValue('min-height');
this.style.height = this.scrollHeight + 0.1 + 'px';
this.style.height = this.scrollHeight + 0.3 + 'px';
if (!isFirefox) {
const newScrollTop = Math.round(chatBlock[0].scrollHeight - (chatBlock.outerHeight() + originalScrollBottom));

View File

@ -35,7 +35,7 @@ import { registerSlashCommand } from './slash-commands.js';
import { tags } from './tags.js';
import { tokenizers } from './tokenizers.js';
import { countOccurrences, debounce, delay, isOdd, resetScrollHeight, sortMoments, stringToRange, timestampToMoment } from './utils.js';
import { countOccurrences, debounce, delay, isOdd, resetScrollHeight, shuffle, sortMoments, stringToRange, timestampToMoment } from './utils.js';
export {
loadPowerUserSettings,
@ -1818,10 +1818,6 @@ export function renderStoryString(params) {
const sortFunc = (a, b) => power_user.sort_order == 'asc' ? compareFunc(a, b) : compareFunc(b, a);
const compareFunc = (first, second) => {
if (power_user.sort_order == 'random') {
return Math.random() > 0.5 ? 1 : -1;
}
const a = first[power_user.sort_field];
const b = second[power_user.sort_field];
@ -1853,6 +1849,11 @@ function sortEntitiesList(entities) {
return;
}
if (power_user.sort_order === 'random') {
shuffle(entities);
return;
}
entities.sort((a, b) => {
if (a.type === 'tag' && b.type !== 'tag') {
return -1;

View File

@ -1,4 +1,4 @@
import { characters, getAPIServerUrl, main_api, nai_settings, online_status, this_chid } from '../script.js';
import { characters, main_api, api_server, api_server_textgenerationwebui, nai_settings, online_status, this_chid } from '../script.js';
import { power_user, registerDebugFunction } from './power-user.js';
import { chat_completion_sources, model_list, oai_settings } from './openai.js';
import { groups, selected_group } from './group-chats.js';
@ -18,9 +18,11 @@ export const tokenizers = {
LLAMA: 3,
NERD: 4,
NERD2: 5,
API: 6,
API_CURRENT: 6,
MISTRAL: 7,
YI: 8,
API_TEXTGENERATIONWEBUI: 9,
API_KOBOLD: 10,
BEST_MATCH: 99,
};
@ -33,6 +35,51 @@ export const SENTENCEPIECE_TOKENIZERS = [
//tokenizers.NERD2,
];
// Maps each tokenizer type to the server endpoints implementing it.
// Local tokenizers expose encode/decode/count; the remote (API-backed)
// tokenizers support only a subset: API_KOBOLD can only count, and
// API_TEXTGENERATIONWEBUI cannot decode.
const TOKENIZER_URLS = {
[tokenizers.GPT2]: {
encode: '/api/tokenizers/gpt2/encode',
decode: '/api/tokenizers/gpt2/decode',
count: '/api/tokenizers/gpt2/encode',
},
[tokenizers.OPENAI]: {
encode: '/api/tokenizers/openai/encode',
decode: '/api/tokenizers/openai/decode',
count: '/api/tokenizers/openai/encode',
},
[tokenizers.LLAMA]: {
encode: '/api/tokenizers/llama/encode',
decode: '/api/tokenizers/llama/decode',
count: '/api/tokenizers/llama/encode',
},
[tokenizers.NERD]: {
encode: '/api/tokenizers/nerdstash/encode',
decode: '/api/tokenizers/nerdstash/decode',
count: '/api/tokenizers/nerdstash/encode',
},
[tokenizers.NERD2]: {
encode: '/api/tokenizers/nerdstash_v2/encode',
decode: '/api/tokenizers/nerdstash_v2/decode',
count: '/api/tokenizers/nerdstash_v2/encode',
},
[tokenizers.API_KOBOLD]: {
count: '/api/tokenizers/remote/kobold/count',
},
[tokenizers.MISTRAL]: {
encode: '/api/tokenizers/mistral/encode',
decode: '/api/tokenizers/mistral/decode',
count: '/api/tokenizers/mistral/encode',
},
[tokenizers.YI]: {
encode: '/api/tokenizers/yi/encode',
decode: '/api/tokenizers/yi/decode',
count: '/api/tokenizers/yi/encode',
},
[tokenizers.API_TEXTGENERATIONWEBUI]: {
encode: '/api/tokenizers/remote/textgenerationwebui/encode',
count: '/api/tokenizers/remote/textgenerationwebui/encode',
},
};
const objectStore = new localforage.createInstance({ name: 'SillyTavern_ChatCompletions' });
let tokenCache = {};
@ -92,7 +139,18 @@ export function getFriendlyTokenizerName(forApi) {
if (forApi !== 'openai' && tokenizerId === tokenizers.BEST_MATCH) {
tokenizerId = getTokenizerBestMatch(forApi);
tokenizerName = $(`#tokenizer option[value="${tokenizerId}"]`).text();
switch (tokenizerId) {
case tokenizers.API_KOBOLD:
tokenizerName = 'API (KoboldAI Classic)';
break;
case tokenizers.API_TEXTGENERATIONWEBUI:
tokenizerName = 'API (Text Completion)';
break;
default:
tokenizerName = $(`#tokenizer option[value="${tokenizerId}"]`).text();
break;
}
}
tokenizerName = forApi == 'openai'
@ -135,11 +193,11 @@ export function getTokenizerBestMatch(forApi) {
if (!hasTokenizerError && isConnected) {
if (forApi === 'kobold' && kai_flags.can_use_tokenization) {
return tokenizers.API;
return tokenizers.API_KOBOLD;
}
if (forApi === 'textgenerationwebui' && isTokenizerSupported) {
return tokenizers.API;
return tokenizers.API_TEXTGENERATIONWEBUI;
}
}
@ -149,34 +207,42 @@ export function getTokenizerBestMatch(forApi) {
return tokenizers.NONE;
}
/**
 * Resolves which remote tokenizer corresponds to the currently selected
 * text generation API.
 * @returns {number} A `tokenizers` enum value; `tokenizers.NONE` when the
 *     active API has no remote tokenizer.
 */
function currentRemoteTokenizerAPI() {
    if (main_api === 'kobold') {
        return tokenizers.API_KOBOLD;
    }
    if (main_api === 'textgenerationwebui') {
        return tokenizers.API_TEXTGENERATIONWEBUI;
    }
    return tokenizers.NONE;
}
/**
* Calls the underlying tokenizer model to the token count for a string.
* @param {number} type Tokenizer type.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count.
*/
function callTokenizer(type, str, padding) {
function callTokenizer(type, str) {
if (type === tokenizers.NONE) return guesstimate(str);
switch (type) {
case tokenizers.NONE:
return guesstimate(str) + padding;
case tokenizers.GPT2:
return countTokensRemote('/api/tokenizers/gpt2/encode', str, padding);
case tokenizers.LLAMA:
return countTokensRemote('/api/tokenizers/llama/encode', str, padding);
case tokenizers.NERD:
return countTokensRemote('/api/tokenizers/nerdstash/encode', str, padding);
case tokenizers.NERD2:
return countTokensRemote('/api/tokenizers/nerdstash_v2/encode', str, padding);
case tokenizers.MISTRAL:
return countTokensRemote('/api/tokenizers/mistral/encode', str, padding);
case tokenizers.YI:
return countTokensRemote('/api/tokenizers/yi/encode', str, padding);
case tokenizers.API:
return countTokensRemote('/tokenize_via_api', str, padding);
default:
console.warn('Unknown tokenizer type', type);
return callTokenizer(tokenizers.NONE, str, padding);
case tokenizers.API_CURRENT:
return callTokenizer(currentRemoteTokenizerAPI(), str);
case tokenizers.API_KOBOLD:
return countTokensFromKoboldAPI(str);
case tokenizers.API_TEXTGENERATIONWEBUI:
return countTokensFromTextgenAPI(str);
default: {
const endpointUrl = TOKENIZER_URLS[type]?.count;
if (!endpointUrl) {
console.warn('Unknown tokenizer type', type);
return apiFailureTokenCount(str);
}
return countTokensFromServer(endpointUrl, str);
}
}
}
@ -219,7 +285,7 @@ export function getTokenCount(str, padding = undefined) {
return cacheObject[cacheKey];
}
const result = callTokenizer(tokenizerType, str, padding);
const result = callTokenizer(tokenizerType, str) + padding;
if (isNaN(result)) {
console.warn('Token count calculation returned NaN');
@ -391,76 +457,131 @@ function getTokenCacheObject() {
return tokenCache[String(chatId)];
}
function getRemoteTokenizationParams(str) {
return {
text: str,
main_api,
api_type: textgen_settings.type,
url: getAPIServerUrl(),
legacy_api: main_api === 'textgenerationwebui' &&
textgen_settings.legacy_api &&
textgen_settings.type !== MANCER,
};
}
/**
* Counts token using the remote server API.
* Count tokens using the server API.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count with padding.
* @returns {number} Token count.
*/
function countTokensRemote(endpoint, str, padding) {
function countTokensFromServer(endpoint, str) {
let tokenCount = 0;
jQuery.ajax({
async: false,
type: 'POST',
url: endpoint,
data: JSON.stringify(getRemoteTokenizationParams(str)),
data: JSON.stringify({ text: str }),
dataType: 'json',
contentType: 'application/json',
success: function (data) {
if (typeof data.count === 'number') {
tokenCount = data.count;
} else {
tokenCount = guesstimate(str);
console.error('Error counting tokens');
if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
toastr.warning(
'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.',
'Error counting tokens',
{ timeOut: 10000, preventDuplicates: true },
);
sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
}
tokenCount = apiFailureTokenCount(str);
}
},
});
return tokenCount + padding;
return tokenCount;
}
/**
 * Count tokens via the connected KoboldAI server's remote API.
 * Falls back to an estimated count (apiFailureTokenCount) when the
 * response does not contain a numeric `count`.
 * @param {string} str String to tokenize.
 * @returns {number} Token count.
 */
function countTokensFromKoboldAPI(str) {
let tokenCount = 0;
jQuery.ajax({
// Synchronous XHR: the caller expects an immediate return value, so
// this blocks the UI thread until the server responds.
async: false,
type: 'POST',
url: TOKENIZER_URLS[tokenizers.API_KOBOLD].count,
data: JSON.stringify({
text: str,
url: api_server,
}),
dataType: 'json',
contentType: 'application/json',
success: function (data) {
if (typeof data.count === 'number') {
tokenCount = data.count;
} else {
tokenCount = apiFailureTokenCount(str);
}
},
});
return tokenCount;
}
/**
 * Builds the request payload shared by the text-generation-webui remote
 * tokenizer endpoints (encode and count).
 * @param {string} str Text to tokenize.
 * @returns {object} JSON-serializable request body.
 */
function getTextgenAPITokenizationParams(str) {
    // Mancer never uses the legacy API, regardless of the user setting.
    const legacyApi =
        textgen_settings.legacy_api &&
        textgen_settings.type !== MANCER;
    return {
        text: str,
        api_type: textgen_settings.type,
        url: api_server_textgenerationwebui,
        legacy_api: legacyApi,
    };
}
/**
 * Count tokens via the connected text-generation-webui server's API.
 * Falls back to an estimated count (apiFailureTokenCount) when the
 * response does not contain a numeric `count`.
 * @param {string} str String to tokenize.
 * @returns {number} Token count.
 */
function countTokensFromTextgenAPI(str) {
let tokenCount = 0;
jQuery.ajax({
// Synchronous XHR - blocks until the server responds.
async: false,
type: 'POST',
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].count,
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
dataType: 'json',
contentType: 'application/json',
success: function (data) {
if (typeof data.count === 'number') {
tokenCount = data.count;
} else {
tokenCount = apiFailureTokenCount(str);
}
},
});
return tokenCount;
}
/**
 * Fallback used when a tokenization API call fails or is unsupported:
 * logs the error, warns the user at most once per session (tracked via
 * TOKENIZER_WARNING_KEY in sessionStorage), and returns a rough estimate.
 * @param {string} str String whose token count is needed.
 * @returns {number} Estimated token count from guesstimate().
 */
function apiFailureTokenCount(str) {
console.error('Error counting tokens');
if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
toastr.warning(
'Your selected API doesn\'t support the tokenization endpoint. Using estimated counts.',
'Error counting tokens',
{ timeOut: 10000, preventDuplicates: true },
);
sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
}
return guesstimate(str);
}
/**
* Calls the underlying tokenizer model to encode a string to tokens.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @param {string} model Tokenizer model.
* @returns {number[]} Array of token ids.
*/
function getTextTokensRemote(endpoint, str, model = '') {
if (model) {
endpoint += `?model=${model}`;
}
function getTextTokensFromServer(endpoint, str) {
let ids = [];
jQuery.ajax({
async: false,
type: 'POST',
url: endpoint,
data: JSON.stringify(getRemoteTokenizationParams(str)),
data: JSON.stringify({ text: str }),
dataType: 'json',
contentType: 'application/json',
success: function (data) {
@ -475,16 +596,33 @@ function getTextTokensRemote(endpoint, str, model = '') {
return ids;
}
/**
 * Encodes a string to token ids via the connected text-generation-webui
 * server's tokenize API.
 * @param {string} str String to tokenize.
 * @returns {number[]} Array of token ids; stays empty if the request fails.
 */
function getTextTokensFromTextgenAPI(str) {
let ids = [];
jQuery.ajax({
// Synchronous XHR - blocks until the server responds.
async: false,
type: 'POST',
url: TOKENIZER_URLS[tokenizers.API_TEXTGENERATIONWEBUI].encode,
data: JSON.stringify(getTextgenAPITokenizationParams(str)),
dataType: 'json',
contentType: 'application/json',
success: function (data) {
ids = data.ids;
},
});
return ids;
}
/**
* Calls the underlying tokenizer model to decode token ids to text.
* @param {string} endpoint API endpoint.
* @param {number[]} ids Array of token ids
*/
function decodeTextTokensRemote(endpoint, ids, model = '') {
if (model) {
endpoint += `?model=${model}`;
}
function decodeTextTokensFromServer(endpoint, ids) {
let text = '';
jQuery.ajax({
async: false,
@ -501,64 +639,62 @@ function decodeTextTokensRemote(endpoint, ids, model = '') {
}
/**
* Encodes a string to tokens using the remote server API.
* Encodes a string to tokens using the server API.
* @param {number} tokenizerType Tokenizer type.
* @param {string} str String to tokenize.
* @returns {number[]} Array of token ids.
*/
export function getTextTokens(tokenizerType, str) {
switch (tokenizerType) {
case tokenizers.GPT2:
return getTextTokensRemote('/api/tokenizers/gpt2/encode', str);
case tokenizers.LLAMA:
return getTextTokensRemote('/api/tokenizers/llama/encode', str);
case tokenizers.NERD:
return getTextTokensRemote('/api/tokenizers/nerdstash/encode', str);
case tokenizers.NERD2:
return getTextTokensRemote('/api/tokenizers/nerdstash_v2/encode', str);
case tokenizers.MISTRAL:
return getTextTokensRemote('/api/tokenizers/mistral/encode', str);
case tokenizers.YI:
return getTextTokensRemote('/api/tokenizers/yi/encode', str);
case tokenizers.OPENAI: {
const model = getTokenizerModel();
return getTextTokensRemote('/api/tokenizers/openai/encode', str, model);
case tokenizers.API_CURRENT:
return getTextTokens(currentRemoteTokenizerAPI(), str);
case tokenizers.API_TEXTGENERATIONWEBUI:
return getTextTokensFromTextgenAPI(str);
default: {
const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
if (!tokenizerEndpoints) {
apiFailureTokenCount(str);
console.warn('Unknown tokenizer type', tokenizerType);
return [];
}
let endpointUrl = tokenizerEndpoints.encode;
if (!endpointUrl) {
apiFailureTokenCount(str);
console.warn('This tokenizer type does not support encoding', tokenizerType);
return [];
}
if (tokenizerType === tokenizers.OPENAI) {
endpointUrl += `?model=${getTokenizerModel()}`;
}
return getTextTokensFromServer(endpointUrl, str);
}
case tokenizers.API:
return getTextTokensRemote('/tokenize_via_api', str);
default:
console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType);
return [];
}
}
/**
* Decodes token ids to text using the remote server API.
* Decodes token ids to text using the server API.
* @param {number} tokenizerType Tokenizer type.
* @param {number[]} ids Array of token ids
*/
export function decodeTextTokens(tokenizerType, ids) {
switch (tokenizerType) {
case tokenizers.GPT2:
return decodeTextTokensRemote('/api/tokenizers/gpt2/decode', ids);
case tokenizers.LLAMA:
return decodeTextTokensRemote('/api/tokenizers/llama/decode', ids);
case tokenizers.NERD:
return decodeTextTokensRemote('/api/tokenizers/nerdstash/decode', ids);
case tokenizers.NERD2:
return decodeTextTokensRemote('/api/tokenizers/nerdstash_v2/decode', ids);
case tokenizers.MISTRAL:
return decodeTextTokensRemote('/api/tokenizers/mistral/decode', ids);
case tokenizers.YI:
return decodeTextTokensRemote('/api/tokenizers/yi/decode', ids);
case tokenizers.OPENAI: {
const model = getTokenizerModel();
return decodeTextTokensRemote('/api/tokenizers/openai/decode', ids, model);
}
default:
console.warn('Calling decodeTextTokens with unsupported tokenizer type', tokenizerType);
return '';
// Currently, neither remote API can decode, but this may change in the future. Put this guard here to be safe
if (tokenizerType === tokenizers.API_CURRENT) {
return decodeTextTokens(tokenizers.NONE, ids);
}
const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
if (!tokenizerEndpoints) {
console.warn('Unknown tokenizer type', tokenizerType);
return [];
}
let endpointUrl = tokenizerEndpoints.decode;
if (!endpointUrl) {
console.warn('This tokenizer type does not support decoding', tokenizerType);
return [];
}
if (tokenizerType === tokenizers.OPENAI) {
endpointUrl += `?model=${getTokenizerModel()}`;
}
return decodeTextTokensFromServer(endpointUrl, ids);
}
export async function initTokenizers() {

View File

@ -1444,9 +1444,7 @@ select option:not(:checked) {
display: block;
}
#api_button:hover,
#api_button_novel:hover,
#api_button_textgenerationwebui:hover {
.menu_button.api_button:hover {
background-color: var(--active);
}

220
server.js
View File

@ -49,6 +49,7 @@ const { delay, getVersion, getConfigValue, color, uuidv4, tryParse, clientRelati
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
const { convertClaudePrompt } = require('./src/chat-completion');
const { getOverrideHeaders, setAdditionalHeaders } = require('./src/additional-headers');
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
@ -119,107 +120,9 @@ const listen = getConfigValue('listen', false);
const API_OPENAI = 'https://api.openai.com/v1';
const API_CLAUDE = 'https://api.anthropic.com/v1';
function getMancerHeaders() {
const apiKey = readSecret(SECRET_KEYS.MANCER);
return apiKey ? ({
'X-API-KEY': apiKey,
'Authorization': `Bearer ${apiKey}`,
}) : {};
}
function getAphroditeHeaders() {
const apiKey = readSecret(SECRET_KEYS.APHRODITE);
return apiKey ? ({
'X-API-KEY': apiKey,
'Authorization': `Bearer ${apiKey}`,
}) : {};
}
function getTabbyHeaders() {
const apiKey = readSecret(SECRET_KEYS.TABBY);
return apiKey ? ({
'x-api-key': apiKey,
'Authorization': `Bearer ${apiKey}`,
}) : {};
}
function getOverrideHeaders(urlHost) {
const requestOverrides = getConfigValue('requestOverrides', []);
const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
if (overrideHeaders && urlHost) {
return overrideHeaders;
} else {
return {};
}
}
/**
* Sets additional headers for the request.
* @param {object} request Original request body
* @param {object} args New request arguments
* @param {string|null} server API server for new request
*/
function setAdditionalHeaders(request, args, server) {
let headers;
switch (request.body.api_type) {
case TEXTGEN_TYPES.MANCER:
headers = getMancerHeaders();
break;
case TEXTGEN_TYPES.APHRODITE:
headers = getAphroditeHeaders();
break;
case TEXTGEN_TYPES.TABBY:
headers = getTabbyHeaders();
break;
default:
headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
break;
}
Object.assign(args.headers, headers);
}
const SETTINGS_FILE = './public/settings.json';
const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, TEXTGEN_TYPES, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, AVATAR_HEIGHT } = require('./src/constants');
// CSRF Protection //
if (!cliArguments.disableCsrf) {
const CSRF_SECRET = crypto.randomBytes(8).toString('hex');
const COOKIES_SECRET = crypto.randomBytes(8).toString('hex');
const { generateToken, doubleCsrfProtection } = doubleCsrf({
getSecret: () => CSRF_SECRET,
cookieName: 'X-CSRF-Token',
cookieOptions: {
httpOnly: true,
sameSite: 'strict',
secure: false,
},
size: 64,
getTokenFromRequest: (req) => req.headers['x-csrf-token'],
});
app.get('/csrf-token', (req, res) => {
res.json({
'token': generateToken(res, req),
});
});
app.use(cookieParser(COOKIES_SECRET));
app.use(doubleCsrfProtection);
} else {
console.warn('\nCSRF protection is disabled. This will make your server vulnerable to CSRF attacks.\n');
app.get('/csrf-token', (req, res) => {
res.json({
'token': 'disabled',
});
});
}
// CORS Settings //
const CORS = cors({
origin: 'null',
@ -273,6 +176,40 @@ app.use(function (req, res, next) {
next();
});
// CSRF Protection //
if (!cliArguments.disableCsrf) {
// Secrets are regenerated on every server start, so tokens issued
// before a restart become invalid.
const CSRF_SECRET = crypto.randomBytes(8).toString('hex');
const COOKIES_SECRET = crypto.randomBytes(8).toString('hex');
const { generateToken, doubleCsrfProtection } = doubleCsrf({
getSecret: () => CSRF_SECRET,
cookieName: 'X-CSRF-Token',
cookieOptions: {
httpOnly: true,
sameSite: 'strict',
secure: false,
},
size: 64,
getTokenFromRequest: (req) => req.headers['x-csrf-token'],
});
// Clients fetch a token here and echo it back in the x-csrf-token header.
app.get('/csrf-token', (req, res) => {
res.json({
'token': generateToken(res, req),
});
});
app.use(cookieParser(COOKIES_SECRET));
app.use(doubleCsrfProtection);
} else {
console.warn('\nCSRF protection is disabled. This will make your server vulnerable to CSRF attacks.\n');
// Keep the endpoint available so clients need no special code path
// when protection is off.
app.get('/csrf-token', (req, res) => {
res.json({
'token': 'disabled',
});
});
}
if (getConfigValue('enableCorsProxy', false) || cliArguments.corsProxy) {
const bodyParser = require('body-parser');
app.use(bodyParser.json());
@ -1774,93 +1711,6 @@ async function sendAI21Request(request, response) {
}
app.post('/tokenize_via_api', jsonParser, async function (request, response) {
if (!request.body) {
return response.sendStatus(400);
}
const text = String(request.body.text) || '';
const api = String(request.body.main_api);
const baseUrl = String(request.body.url);
const legacyApi = Boolean(request.body.legacy_api);
try {
if (api == 'textgenerationwebui') {
const args = {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
};
setAdditionalHeaders(request, args, null);
// Convert to string + remove trailing slash + /v1 suffix
let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, '');
if (legacyApi) {
url += '/v1/token-count';
args.body = JSON.stringify({ 'prompt': text });
} else {
switch (request.body.api_type) {
case TEXTGEN_TYPES.TABBY:
url += '/v1/token/encode';
args.body = JSON.stringify({ 'text': text });
break;
case TEXTGEN_TYPES.KOBOLDCPP:
url += '/api/extra/tokencount';
args.body = JSON.stringify({ 'prompt': text });
break;
default:
url += '/v1/internal/encode';
args.body = JSON.stringify({ 'text': text });
break;
}
}
const result = await fetch(url, args);
if (!result.ok) {
console.log(`API returned error: ${result.status} ${result.statusText}`);
return response.send({ error: true });
}
const data = await result.json();
const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
const ids = legacyApi ? [] : (data?.tokens ?? []);
return response.send({ count, ids });
}
else if (api == 'kobold') {
const args = {
method: 'POST',
body: JSON.stringify({ 'prompt': text }),
headers: { 'Content-Type': 'application/json' },
};
let url = String(baseUrl).replace(/\/$/, '');
url += '/extra/tokencount';
const result = await fetch(url, args);
if (!result.ok) {
console.log(`API returned error: ${result.status} ${result.statusText}`);
return response.send({ error: true });
}
const data = await result.json();
const count = data['value'];
return response.send({ count: count, ids: [] });
}
else {
console.log('Unknown API', api);
return response.send({ error: true });
}
} catch (error) {
console.log(error);
return response.send({ error: true });
}
});
/**
* Redirect a deprecated API endpoint URL to its replacement. Because fetch, form submissions, and $.ajax follow
* redirects, this is transparent to client-side code.

72
src/additional-headers.js Normal file
View File

@ -0,0 +1,72 @@
const { TEXTGEN_TYPES } = require('./constants');
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
const { getConfigValue } = require('./util');
/**
 * Builds auth headers for Mancer requests from the stored API key secret.
 * @returns {object} Auth headers, or an empty object when no key is stored.
 */
function getMancerHeaders() {
const apiKey = readSecret(SECRET_KEYS.MANCER);
return apiKey ? ({
'X-API-KEY': apiKey,
'Authorization': `Bearer ${apiKey}`,
}) : {};
}
/**
 * Builds auth headers for Aphrodite requests from the stored API key secret.
 * @returns {object} Auth headers, or an empty object when no key is stored.
 */
function getAphroditeHeaders() {
const apiKey = readSecret(SECRET_KEYS.APHRODITE);
return apiKey ? ({
'X-API-KEY': apiKey,
'Authorization': `Bearer ${apiKey}`,
}) : {};
}
/**
 * Builds auth headers for TabbyAPI requests from the stored API key secret.
 * Note the lowercase 'x-api-key' - Tabby expects this casing, unlike
 * Mancer/Aphrodite above.
 * @returns {object} Auth headers, or an empty object when no key is stored.
 */
function getTabbyHeaders() {
const apiKey = readSecret(SECRET_KEYS.TABBY);
return apiKey ? ({
'x-api-key': apiKey,
'Authorization': `Bearer ${apiKey}`,
}) : {};
}
/**
 * Looks up user-configured header overrides for a given host.
 * @param {string} urlHost Host (e.g. from `new URL(...).host`) matched
 *     against the `hosts` lists in the `requestOverrides` config entries.
 * @returns {object} The configured headers for that host, or an empty object.
 */
function getOverrideHeaders(urlHost) {
const requestOverrides = getConfigValue('requestOverrides', []);
const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
if (overrideHeaders && urlHost) {
return overrideHeaders;
} else {
return {};
}
}
/**
 * Sets additional headers for the request.
 * Picks provider-specific auth headers based on the request's `api_type`;
 * for unknown types it falls back to per-host override headers from config
 * when a server URL is supplied.
 * @param {object} request Original request body
 * @param {object} args New request arguments
 * @param {string|null} server API server for new request
 */
function setAdditionalHeaders(request, args, server) {
let headers;
switch (request.body.api_type) {
case TEXTGEN_TYPES.MANCER:
headers = getMancerHeaders();
break;
case TEXTGEN_TYPES.APHRODITE:
headers = getAphroditeHeaders();
break;
case TEXTGEN_TYPES.TABBY:
headers = getTabbyHeaders();
break;
default:
headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
break;
}
// Mutates args.headers in place; callers must initialize args.headers first.
Object.assign(args.headers, headers);
}
module.exports = {
getOverrideHeaders,
setAdditionalHeaders,
};

View File

@ -6,7 +6,9 @@ const tiktoken = require('@dqbd/tiktoken');
const { Tokenizer } = require('@agnai/web-tokenizers');
const { convertClaudePrompt } = require('../chat-completion');
const { readSecret, SECRET_KEYS } = require('./secrets');
const { TEXTGEN_TYPES } = require('../constants');
const { jsonParser } = require('../express-common');
const { setAdditionalHeaders } = require('../additional-headers');
/**
* @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
@ -534,6 +536,96 @@ router.post('/openai/count', jsonParser, async function (req, res) {
}
});
/**
 * POST /remote/kobold/count
 * Proxies a token-count request to a KoboldAI server's
 * `/extra/tokencount` endpoint.
 * Body: { text: string, url: string (Kobold server base URL) }
 * Responds with { count, ids: [] } on success or { error: true } on failure.
 */
router.post('/remote/kobold/count', jsonParser, async function (request, response) {
    if (!request.body) {
        return response.sendStatus(400);
    }

    // Coalesce BEFORE stringifying: String(undefined) is the truthy string
    // 'undefined', so the previous `String(x) || ''` fallback never fired
    // and a missing `text` field was counted as the literal word "undefined".
    const text = String(request.body.text ?? '');
    const baseUrl = String(request.body.url);

    try {
        const args = {
            method: 'POST',
            body: JSON.stringify({ 'prompt': text }),
            headers: { 'Content-Type': 'application/json' },
        };

        // Strip a trailing slash so the path concatenation stays well-formed.
        let url = String(baseUrl).replace(/\/$/, '');
        url += '/extra/tokencount';

        const result = await fetch(url, args);

        if (!result.ok) {
            console.log(`API returned error: ${result.status} ${result.statusText}`);
            return response.send({ error: true });
        }

        const data = await result.json();
        const count = data['value'];
        return response.send({ count, ids: [] });
    } catch (error) {
        console.log(error);
        return response.send({ error: true });
    }
});
/**
 * POST /remote/textgenerationwebui/encode
 * Proxies a tokenize/count request to a text-generation-webui-compatible
 * server, choosing the endpoint and body shape per backend type.
 * Body: { text: string, url: string, api_type?: string, legacy_api?: boolean }
 * Responds with { count, ids } on success or { error: true } on failure.
 */
router.post('/remote/textgenerationwebui/encode', jsonParser, async function (request, response) {
    if (!request.body) {
        return response.sendStatus(400);
    }

    // Coalesce BEFORE stringifying: String(undefined) is the truthy string
    // 'undefined', so the previous `String(x) || ''` fallback never fired
    // and a missing `text` field was tokenized as the literal word "undefined".
    const text = String(request.body.text ?? '');
    const baseUrl = String(request.body.url);
    const legacyApi = Boolean(request.body.legacy_api);

    try {
        const args = {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
        };

        setAdditionalHeaders(request, args, null);

        // Convert to string + remove trailing slash + /v1 suffix
        let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, '');

        if (legacyApi) {
            // Legacy API exposes counts only; no token ids are returned.
            url += '/v1/token-count';
            args.body = JSON.stringify({ 'prompt': text });
        } else {
            switch (request.body.api_type) {
                case TEXTGEN_TYPES.TABBY:
                    url += '/v1/token/encode';
                    args.body = JSON.stringify({ 'text': text });
                    break;
                case TEXTGEN_TYPES.KOBOLDCPP:
                    url += '/api/extra/tokencount';
                    args.body = JSON.stringify({ 'prompt': text });
                    break;
                default:
                    url += '/v1/internal/encode';
                    args.body = JSON.stringify({ 'text': text });
                    break;
            }
        }

        const result = await fetch(url, args);

        if (!result.ok) {
            console.log(`API returned error: ${result.status} ${result.statusText}`);
            return response.send({ error: true });
        }

        const data = await result.json();
        // Response shapes differ by backend: legacy puts the count in
        // results[0].tokens; modern backends return `length` or `value`,
        // and optionally a `tokens` id array.
        const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
        const ids = legacyApi ? [] : (data?.tokens ?? []);

        return response.send({ count, ids });
    } catch (error) {
        console.log(error);
        return response.send({ error: true });
    }
});
module.exports = {
TEXT_COMPLETION_MODELS,
getTokenizerModel,