Merge branch 'vectors' into staging

This commit is contained in:
Cohee 2023-09-12 15:49:47 +03:00
commit 2f8f6844fe
28 changed files with 2822 additions and 506 deletions

2
.gitignore vendored
View File

@ -35,3 +35,5 @@ content.log
cloudflared.exe cloudflared.exe
public/assets/ public/assets/
access.log access.log
/vectors/
/cache/

View File

@ -15,7 +15,15 @@ const skipContentCheck = false; // If true, no new default content will be deliv
// Change this setting only on "trusted networks". Do not change this value unless you are aware of the issues that can arise from changing this setting and configuring a insecure setting. // Change this setting only on "trusted networks". Do not change this value unless you are aware of the issues that can arise from changing this setting and configuring a insecure setting.
const securityOverride = false; const securityOverride = false;
// Additional settings for extra modules / extensions
const extras = {
// Text classification model for sentiment analysis. HuggingFace ID of a model in ONNX format.
classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
};
// Request overrides for additional headers // Request overrides for additional headers
// Format is an array of objects:
// { hosts: [ "<url>" ], headers: { <header>: "<value>" } }
const requestOverrides = []; const requestOverrides = [];
module.exports = { module.exports = {
@ -32,4 +40,5 @@ module.exports = {
securityOverride, securityOverride,
skipContentCheck, skipContentCheck,
requestOverrides, requestOverrides,
extras,
}; };

1199
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,8 @@
"@agnai/sentencepiece-js": "^1.1.1", "@agnai/sentencepiece-js": "^1.1.1",
"@agnai/web-tokenizers": "^0.1.3", "@agnai/web-tokenizers": "^0.1.3",
"@dqbd/tiktoken": "^1.0.2", "@dqbd/tiktoken": "^1.0.2",
"@tensorflow-models/universal-sentence-encoder": "^1.3.3",
"@tensorflow/tfjs": "^4.10.0",
"command-exists": "^1.2.9", "command-exists": "^1.2.9",
"compression": "^1", "compression": "^1",
"cookie-parser": "^1.4.6", "cookie-parser": "^1.4.6",
@ -15,7 +17,7 @@
"gpt3-tokenizer": "^1.1.5", "gpt3-tokenizer": "^1.1.5",
"ip-matching": "^2.1.2", "ip-matching": "^2.1.2",
"ipaddr.js": "^2.0.1", "ipaddr.js": "^2.0.1",
"jimp": "^0.22.7", "jimp": "^0.22.10",
"jquery": "^3.6.4", "jquery": "^3.6.4",
"json5": "^2.2.3", "json5": "^2.2.3",
"lodash": "^4.17.21", "lodash": "^4.17.21",
@ -29,8 +31,10 @@
"png-chunks-extract": "^1.0.0", "png-chunks-extract": "^1.0.0",
"response-time": "^2.3.2", "response-time": "^2.3.2",
"sanitize-filename": "^1.6.3", "sanitize-filename": "^1.6.3",
"sillytavern-transformers": "^2.7.3",
"simple-git": "^3.19.1", "simple-git": "^3.19.1",
"uniqolor": "^1.1.0", "uniqolor": "^1.1.0",
"vectra": "^0.2.2",
"webp-converter": "2.3.2", "webp-converter": "2.3.2",
"write-file-atomic": "^5.0.1", "write-file-atomic": "^5.0.1",
"ws": "^8.13.0", "ws": "^8.13.0",
@ -49,7 +53,7 @@
"type": "git", "type": "git",
"url": "https://github.com/SillyTavern/SillyTavern.git" "url": "https://github.com/SillyTavern/SillyTavern.git"
}, },
"version": "1.10.2", "version": "1.10.3",
"scripts": { "scripts": {
"start": "node server.js", "start": "node server.js",
"start-multi": "node server.js --disableCsrf", "start-multi": "node server.js --disableCsrf",

View File

@ -121,7 +121,7 @@ import {
delay, delay,
restoreCaretPosition, restoreCaretPosition,
saveCaretPosition, saveCaretPosition,
end_trim_to_sentence, trimToEndSentence,
countOccurrences, countOccurrences,
isOdd, isOdd,
sortMoments, sortMoments,
@ -286,7 +286,9 @@ export const event_types = {
CHARACTER_EDITED: 'character_edited', CHARACTER_EDITED: 'character_edited',
USER_MESSAGE_RENDERED: 'user_message_rendered', USER_MESSAGE_RENDERED: 'user_message_rendered',
CHARACTER_MESSAGE_RENDERED: 'character_message_rendered', CHARACTER_MESSAGE_RENDERED: 'character_message_rendered',
FORCE_SET_BACKGROUND: 'force_set_background,' FORCE_SET_BACKGROUND: 'force_set_background',
CHAT_DELETED : 'chat_deleted',
GROUP_CHAT_DELETED: 'group_chat_deleted',
} }
export const eventSource = new EventEmitter(); export const eventSource = new EventEmitter();
@ -383,10 +385,7 @@ const system_message_types = {
}; };
const extension_prompt_types = { const extension_prompt_types = {
/** IN_PROMPT: 0,
* @deprecated Outdated term. In reality it's "after main prompt or story string"
*/
AFTER_SCENARIO: 0,
IN_CHAT: 1 IN_CHAT: 1
}; };
@ -1110,10 +1109,12 @@ async function delChat(chatfile) {
}); });
if (response.ok === true) { if (response.ok === true) {
// choose another chat if current was deleted // choose another chat if current was deleted
if (chatfile.replace('.jsonl', '') === characters[this_chid].chat) { const name = chatfile.replace('.jsonl', '');
if (name === characters[this_chid].chat) {
chat_metadata = {}; chat_metadata = {};
await replaceCurrentChat(); await replaceCurrentChat();
} }
await eventSource.emit(event_types.CHAT_DELETED, name);
} }
} }
@ -1148,10 +1149,42 @@ async function replaceCurrentChat() {
} }
} }
function printMessages() { const TRUNCATION_THRESHOLD = 100;
chat.forEach(function (item, i, arr) {
addOneMessage(item, { scroll: i === arr.length - 1 }); function showMoreMessages() {
}); let messageId = Number($('#chat').children('.mes').first().attr('mesid'));
let count = TRUNCATION_THRESHOLD;
console.debug('Inserting messages before', messageId, 'count', count, 'chat length', chat.length);
const prevHeight = $('#chat').prop('scrollHeight');
while(messageId > 0 && count > 0) {
count--;
messageId--;
addOneMessage(chat[messageId], { insertBefore: messageId + 1, scroll: false, forceId: messageId });
}
if (messageId == 0) {
$('#show_more_messages').remove();
}
const newHeight = $('#chat').prop('scrollHeight');
$('#chat').scrollTop(newHeight - prevHeight);
}
async function printMessages() {
let startIndex = 0;
if (chat.length > TRUNCATION_THRESHOLD) {
count_view_mes = chat.length - TRUNCATION_THRESHOLD;
startIndex = count_view_mes;
$('#chat').append('<div id="show_more_messages">Show more messages</div>');
}
for (let i = startIndex; i < chat.length; i++) {
const item = chat[i];
addOneMessage(item, { scroll: i === chat.length - 1 });
}
if (power_user.lazy_load > 0) { if (power_user.lazy_load > 0) {
const height = $('#chat').height(); const height = $('#chat').height();
@ -1194,7 +1227,7 @@ export async function reloadCurrentChat() {
} }
else { else {
resetChatState(); resetChatState();
printMessages(); await printMessages();
} }
await eventSource.emit(event_types.CHAT_CHANGED, getCurrentChatId()); await eventSource.emit(event_types.CHAT_CHANGED, getCurrentChatId());
@ -1430,7 +1463,7 @@ export function addCopyToCodeBlocks(messageElement) {
} }
function addOneMessage(mes, { type = "normal", insertAfter = null, scroll = true } = {}) { function addOneMessage(mes, { type = "normal", insertAfter = null, scroll = true, insertBefore = null, forceId = null } = {}) {
var messageText = mes["mes"]; var messageText = mes["mes"];
const momentDate = timestampToMoment(mes.send_date); const momentDate = timestampToMoment(mes.send_date);
const timestamp = momentDate.isValid() ? momentDate.format('LL LT') : ''; const timestamp = momentDate.isValid() ? momentDate.format('LL LT') : '';
@ -1500,7 +1533,7 @@ function addOneMessage(mes, { type = "normal", insertAfter = null, scroll = true
} }
}*/ }*/
let params = { let params = {
mesId: count_view_mes, mesId: forceId ?? count_view_mes,
characterName: characterName, characterName: characterName,
isUser: mes.is_user, isUser: mes.is_user,
avatarImg: avatarImg, avatarImg: avatarImg,
@ -1518,18 +1551,31 @@ function addOneMessage(mes, { type = "normal", insertAfter = null, scroll = true
const HTMLForEachMes = getMessageFromTemplate(params); const HTMLForEachMes = getMessageFromTemplate(params);
if (type !== 'swipe') { if (type !== 'swipe') {
if (!insertAfter) { if (!insertAfter && !insertBefore) {
$("#chat").append(HTMLForEachMes); $("#chat").append(HTMLForEachMes);
} }
else { else if (insertAfter) {
const target = $("#chat").find(`.mes[mesid="${insertAfter}"]`); const target = $("#chat").find(`.mes[mesid="${insertAfter}"]`);
$(HTMLForEachMes).insertAfter(target); $(HTMLForEachMes).insertAfter(target);
$(HTMLForEachMes).find('.swipe_left').css('display', 'none'); $(HTMLForEachMes).find('.swipe_left').css('display', 'none');
$(HTMLForEachMes).find('.swipe_right').css('display', 'none'); $(HTMLForEachMes).find('.swipe_right').css('display', 'none');
} else {
const target = $("#chat").find(`.mes[mesid="${insertBefore}"]`);
$(HTMLForEachMes).insertBefore(target);
$(HTMLForEachMes).find('.swipe_left').css('display', 'none');
$(HTMLForEachMes).find('.swipe_right').css('display', 'none');
} }
} }
const newMessageId = type == 'swipe' ? count_view_mes - 1 : count_view_mes; function getMessageId() {
if (typeof forceId == 'number') {
return forceId;
}
return type == 'swipe' ? count_view_mes - 1 : count_view_mes;
}
const newMessageId = getMessageId();
const newMessage = $(`#chat [mesid="${newMessageId}"]`); const newMessage = $(`#chat [mesid="${newMessageId}"]`);
const isSmallSys = mes?.extra?.isSmallSys; const isSmallSys = mes?.extra?.isSmallSys;
newMessage.data("isSystem", isSystem); newMessage.data("isSystem", isSystem);
@ -1603,6 +1649,11 @@ function addOneMessage(mes, { type = "normal", insertAfter = null, scroll = true
swipeMessage.find('.mes_timer').html(''); swipeMessage.find('.mes_timer').html('');
swipeMessage.find('.tokenCounterDisplay').html(''); swipeMessage.find('.tokenCounterDisplay').html('');
} }
} else if (typeof forceId == 'number') {
$("#chat").find(`[mesid="${forceId}"]`).find('.mes_text').append(messageText);
appendImageToMessage(mes, newMessage);
hideSwipeButtons();
showSwipeButtons();
} else { } else {
$("#chat").find(`[mesid="${count_view_mes}"]`).find('.mes_text').append(messageText); $("#chat").find(`[mesid="${count_view_mes}"]`).find('.mes_text').append(messageText);
appendImageToMessage(mes, newMessage); appendImageToMessage(mes, newMessage);
@ -1613,7 +1664,7 @@ function addOneMessage(mes, { type = "normal", insertAfter = null, scroll = true
addCopyToCodeBlocks(newMessage); addCopyToCodeBlocks(newMessage);
// Don't scroll if not inserting last // Don't scroll if not inserting last
if (!insertAfter && scroll) { if (!insertAfter && !insertBefore && scroll) {
$('#chat .mes').last().addClass('last_mes'); $('#chat .mes').last().addClass('last_mes');
$('#chat .mes').eq(-2).removeClass('last_mes'); $('#chat .mes').eq(-2).removeClass('last_mes');
@ -2534,7 +2585,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
addPersonaDescriptionExtensionPrompt(); addPersonaDescriptionExtensionPrompt();
// Call combined AN into Generate // Call combined AN into Generate
let allAnchors = getAllExtensionPrompts(); let allAnchors = getAllExtensionPrompts();
const afterScenarioAnchor = getExtensionPrompt(extension_prompt_types.AFTER_SCENARIO); const afterScenarioAnchor = getExtensionPrompt(extension_prompt_types.IN_PROMPT);
let zeroDepthAnchor = getExtensionPrompt(extension_prompt_types.IN_CHAT, 0, ' '); let zeroDepthAnchor = getExtensionPrompt(extension_prompt_types.IN_CHAT, 0, ' ');
const storyStringParams = { const storyStringParams = {
@ -3679,7 +3730,13 @@ function setInContextMessages(lastmsg, type) {
lastmsg++; lastmsg++;
} }
$('#chat .mes:not([is_system="true"])').eq(-lastmsg).addClass('lastInContext'); const lastMessageBlock = $('#chat .mes:not([is_system="true"])').eq(-lastmsg);
lastMessageBlock.addClass('lastInContext');
if (lastMessageBlock.length === 0) {
const firstMessageId = getFirstDisplayedMessageId();
$(`#chat .mes[mesid="${firstMessageId}"`).addClass('lastInContext');
}
} }
function getGenerateUrl() { function getGenerateUrl() {
@ -3784,7 +3841,7 @@ function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete
getMessage = getRegexedString(getMessage, isImpersonate ? regex_placement.USER_INPUT : regex_placement.AI_OUTPUT); getMessage = getRegexedString(getMessage, isImpersonate ? regex_placement.USER_INPUT : regex_placement.AI_OUTPUT);
if (!displayIncompleteSentences && power_user.trim_sentences) { if (!displayIncompleteSentences && power_user.trim_sentences) {
getMessage = end_trim_to_sentence(getMessage, power_user.include_newline); getMessage = trimToEndSentence(getMessage, power_user.include_newline);
} }
if (power_user.collapse_newlines) { if (power_user.collapse_newlines) {
@ -4452,7 +4509,7 @@ async function getChatResult() {
chat.push(message); chat.push(message);
await saveChatConditional(); await saveChatConditional();
} }
printMessages(); await printMessages();
select_selected_character(this_chid); select_selected_character(this_chid);
await eventSource.emit(event_types.CHAT_CHANGED, (getCurrentChatId())); await eventSource.emit(event_types.CHAT_CHANGED, (getCurrentChatId()));
@ -5014,12 +5071,12 @@ async function saveSettings(type) {
dataType: "json", dataType: "json",
contentType: "application/json", contentType: "application/json",
//processData: false, //processData: false,
success: function (data) { success: async function (data) {
//online_status = data.result; //online_status = data.result;
eventSource.emit(event_types.SETTINGS_UPDATED); eventSource.emit(event_types.SETTINGS_UPDATED);
if (type == "change_name") { if (type == "change_name") {
clearChat(); clearChat();
printMessages(); await printMessages();
} }
}, },
error: function (jqXHR, exception) { error: function (jqXHR, exception) {
@ -5605,7 +5662,7 @@ function select_rm_characters() {
* @param {number} position Insertion position. 0 is after story string, 1 is in-chat with custom depth. * @param {number} position Insertion position. 0 is after story string, 1 is in-chat with custom depth.
* @param {number} depth Insertion depth. 0 represets the last message in context. Expected values up to 100. * @param {number} depth Insertion depth. 0 represets the last message in context. Expected values up to 100.
*/ */
function setExtensionPrompt(key, value, position, depth) { export function setExtensionPrompt(key, value, position, depth) {
extension_prompts[key] = { value: String(value), position: Number(position), depth: Number(depth) }; extension_prompts[key] = { value: String(value), position: Number(position), depth: Number(depth) };
} }
@ -5893,9 +5950,11 @@ async function importCharacterChat(formData) {
} }
function updateViewMessageIds() { function updateViewMessageIds() {
const minId = getFirstDisplayedMessageId();
$('#chat').find(".mes").each(function (index, element) { $('#chat').find(".mes").each(function (index, element) {
$(element).attr("mesid", index); $(element).attr("mesid", minId + index);
$(element).find('.mesIDDisplay').text(`#${index}`); $(element).find('.mesIDDisplay').text(`#${minId + index}`);
}); });
$('#chat .mes').removeClass('last_mes'); $('#chat .mes').removeClass('last_mes');
@ -5904,6 +5963,12 @@ function updateViewMessageIds() {
updateEditArrowClasses(); updateEditArrowClasses();
} }
function getFirstDisplayedMessageId() {
const allIds = Array.from(document.querySelectorAll('#chat .mes')).map(el => Number(el.getAttribute('mesid'))).filter(x => !isNaN(x));
const minId = Math.min(...allIds);
return minId;
}
function updateEditArrowClasses() { function updateEditArrowClasses() {
$("#chat .mes .mes_edit_up").removeClass("disabled"); $("#chat .mes .mes_edit_up").removeClass("disabled");
$("#chat .mes .mes_edit_down").removeClass("disabled"); $("#chat .mes .mes_edit_down").removeClass("disabled");
@ -6306,7 +6371,7 @@ async function createOrEditCharacter(e) {
await eventSource.emit(event_types.MESSAGE_RECEIVED, (chat.length - 1)); await eventSource.emit(event_types.MESSAGE_RECEIVED, (chat.length - 1));
clearChat(); clearChat();
printMessages(); await printMessages();
await eventSource.emit(event_types.CHARACTER_MESSAGE_RENDERED, (chat.length - 1)); await eventSource.emit(event_types.CHARACTER_MESSAGE_RENDERED, (chat.length - 1));
await saveChatConditional(); await saveChatConditional();
} }
@ -6893,6 +6958,7 @@ export async function handleDeleteCharacter(popup_type, this_chid, delete_chats)
const avatar = characters[this_chid].avatar; const avatar = characters[this_chid].avatar;
const name = characters[this_chid].name; const name = characters[this_chid].name;
const pastChats = await getPastCharacterChats();
const msg = { avatar_url: avatar, delete_chats: delete_chats }; const msg = { avatar_url: avatar, delete_chats: delete_chats };
@ -6905,6 +6971,13 @@ export async function handleDeleteCharacter(popup_type, this_chid, delete_chats)
if (response.ok) { if (response.ok) {
await deleteCharacter(name, avatar); await deleteCharacter(name, avatar);
if (delete_chats) {
for (const chat of pastChats) {
const name = chat.file_name.replace('.jsonl', '');
await eventSource.emit(event_types.CHAT_DELETED, name);
}
}
} else { } else {
console.error('Failed to delete character: ', response.status, response.statusText); console.error('Failed to delete character: ', response.status, response.statusText);
} }
@ -6935,7 +7008,7 @@ export async function deleteCharacter(name, avatar) {
delete tag_map[avatar]; delete tag_map[avatar];
await getCharacters(); await getCharacters();
select_rm_info("char_delete", name); select_rm_info("char_delete", name);
printMessages(); await printMessages();
saveSettingsDebounced(); saveSettingsDebounced();
} }
@ -8835,6 +8908,10 @@ jQuery(async function () {
$('#avatar-and-name-block').slideToggle() $('#avatar-and-name-block').slideToggle()
}); });
$(document).on('mouseup touchend', '#show_more_messages', () => {
showMoreMessages();
});
// Added here to prevent execution before script.js is loaded and get rid of quirky timeouts // Added here to prevent execution before script.js is loaded and get rid of quirky timeouts
await firstLoadInit(); await firstLoadInit();

View File

@ -153,6 +153,8 @@ const extension_settings = {
}, },
speech_recognition: {}, speech_recognition: {},
rvc: {}, rvc: {},
hypebot: {},
vectors: {},
}; };
let modules = []; let modules = [];

View File

@ -3,11 +3,12 @@ import { dragElement, isMobile } from "../../RossAscends-mods.js";
import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplate } from "../../extensions.js"; import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplate } from "../../extensions.js";
import { loadMovingUIState, power_user } from "../../power-user.js"; import { loadMovingUIState, power_user } from "../../power-user.js";
import { registerSlashCommand } from "../../slash-commands.js"; import { registerSlashCommand } from "../../slash-commands.js";
import { onlyUnique, debounce, getCharaFilename } from "../../utils.js"; import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence } from "../../utils.js";
export { MODULE_NAME }; export { MODULE_NAME };
const MODULE_NAME = 'expressions'; const MODULE_NAME = 'expressions';
const UPDATE_INTERVAL = 2000; const UPDATE_INTERVAL = 2000;
const STREAMING_UPDATE_INTERVAL = 6000;
const FALLBACK_EXPRESSION = 'joy'; const FALLBACK_EXPRESSION = 'joy';
const DEFAULT_EXPRESSIONS = [ const DEFAULT_EXPRESSIONS = [
"talkinghead", "talkinghead",
@ -46,6 +47,7 @@ let lastCharacter = undefined;
let lastMessage = null; let lastMessage = null;
let spriteCache = {}; let spriteCache = {};
let inApiCall = false; let inApiCall = false;
let lastServerResponseTime = 0;
function isVisualNovelMode() { function isVisualNovelMode() {
return Boolean(!isMobile() && power_user.waifuMode && getContext().groupId); return Boolean(!isMobile() && power_user.waifuMode && getContext().groupId);
@ -447,7 +449,7 @@ function handleImageChange() {
return; return;
} }
if (extension_settings.expressions.talkinghead) { if (extension_settings.expressions.talkinghead && !extension_settings.expressions.local) {
// Method get IP of endpoint // Method get IP of endpoint
const talkingheadResultFeedSrc = `${getApiUrl()}/api/talkinghead/result_feed`; const talkingheadResultFeedSrc = `${getApiUrl()}/api/talkinghead/result_feed`;
$('#expression-holder').css({ display: '' }); $('#expression-holder').css({ display: '' });
@ -477,6 +479,14 @@ function handleImageChange() {
async function moduleWorker() { async function moduleWorker() {
const context = getContext(); const context = getContext();
// Hide and disable talkinghead while in local mode
$('#image_type_block').toggle(!extension_settings.expressions.local);
if (extension_settings.expressions.local && extension_settings.expressions.talkinghead) {
$('#image_type_toggle').prop('checked', false);
setTalkingHeadState(false);
}
// non-characters not supported // non-characters not supported
if (!context.groupId && (context.characterId === undefined || context.characterId === 'invalid-safety-id')) { if (!context.groupId && (context.characterId === undefined || context.characterId === 'invalid-safety-id')) {
removeExpression(); removeExpression();
@ -530,7 +540,7 @@ async function moduleWorker() {
} }
const offlineMode = $('.expression_settings .offline_mode'); const offlineMode = $('.expression_settings .offline_mode');
if (!modules.includes('classify')) { if (!modules.includes('classify') && !extension_settings.expressions.local) {
$('.expression_settings').show(); $('.expression_settings').show();
offlineMode.css('display', 'block'); offlineMode.css('display', 'block');
lastCharacter = context.groupId || context.characterId; lastCharacter = context.groupId || context.characterId;
@ -566,6 +576,17 @@ async function moduleWorker() {
return; return;
} }
// Throttle classification requests during streaming
if (context.streamingProcessor && !context.streamingProcessor.isFinished) {
const now = Date.now();
const timeSinceLastServerResponse = now - lastServerResponseTime;
if (timeSinceLastServerResponse < STREAMING_UPDATE_INTERVAL) {
console.log('Streaming in progress: throttling expression update. Next update at ' + new Date(lastServerResponseTime + STREAMING_UPDATE_INTERVAL));
return;
}
}
try { try {
inApiCall = true; inApiCall = true;
let expression = await getExpressionLabel(currentLastMessage.mes); let expression = await getExpressionLabel(currentLastMessage.mes);
@ -583,7 +604,6 @@ async function moduleWorker() {
} }
await sendExpressionCall(spriteFolderName, expression, force, vnMode); await sendExpressionCall(spriteFolderName, expression, force, vnMode);
} }
catch (error) { catch (error) {
console.log(error); console.log(error);
@ -592,6 +612,7 @@ async function moduleWorker() {
inApiCall = false; inApiCall = false;
lastCharacter = context.groupId || context.characterId; lastCharacter = context.groupId || context.characterId;
lastMessage = currentLastMessage.mes; lastMessage = currentLastMessage.mes;
lastServerResponseTime = Date.now();
} }
} }
@ -635,6 +656,10 @@ function setTalkingHeadState(switch_var) {
extension_settings.expressions.talkinghead = switch_var; // Store setting extension_settings.expressions.talkinghead = switch_var; // Store setting
saveSettingsDebounced(); saveSettingsDebounced();
if (extension_settings.expressions.local) {
return;
}
talkingHeadCheck().then(result => { talkingHeadCheck().then(result => {
if (result) { if (result) {
//console.log("talkinghead exists!"); //console.log("talkinghead exists!");
@ -709,27 +734,77 @@ async function setSpriteSlashCommand(_, spriteId) {
await sendExpressionCall(spriteFolderName, spriteItem.label, true, vnMode); await sendExpressionCall(spriteFolderName, spriteItem.label, true, vnMode);
} }
/**
* Processes the classification text to reduce the amount of text sent to the API.
* Quotes and asterisks are to be removed. If the text is less than 300 characters, it is returned as is.
* If the text is more than 300 characters, the first and last 150 characters are returned.
* The result is trimmed to the end of sentence.
* @param {string} text The text to process.
* @returns {string}
*/
function sampleClassifyText(text) {
if (!text) {
return text;
}
// Remove asterisks and quotes
let result = text.replace(/[\*\"]/g, '');
const SAMPLE_THRESHOLD = 300;
const HALF_SAMPLE_THRESHOLD = SAMPLE_THRESHOLD / 2;
if (text.length < SAMPLE_THRESHOLD) {
result = trimToEndSentence(result);
} else {
result = trimToEndSentence(result.slice(0, HALF_SAMPLE_THRESHOLD)) + ' ' + trimToStartSentence(result.slice(-HALF_SAMPLE_THRESHOLD));
}
return result.trim();
}
async function getExpressionLabel(text) { async function getExpressionLabel(text) {
// Return if text is undefined, saving a costly fetch request // Return if text is undefined, saving a costly fetch request
if (!modules.includes('classify') || !text) { if ((!modules.includes('classify') && !extension_settings.expressions.local) || !text) {
return FALLBACK_EXPRESSION; return FALLBACK_EXPRESSION;
} }
const url = new URL(getApiUrl()); text = sampleClassifyText(text);
url.pathname = '/api/classify';
const apiResult = await doExtrasFetch(url, { try {
method: 'POST', if (extension_settings.expressions.local) {
headers: { // Local transformers pipeline
'Content-Type': 'application/json', const apiResult = await fetch('/api/extra/classify', {
'Bypass-Tunnel-Reminder': 'bypass', method: 'POST',
}, headers: getRequestHeaders(),
body: JSON.stringify({ text: text }), body: JSON.stringify({ text: text }),
}); });
if (apiResult.ok) { if (apiResult.ok) {
const data = await apiResult.json(); const data = await apiResult.json();
return data.classification[0].label; return data.classification[0].label;
}
} else {
// Extras
const url = new URL(getApiUrl());
url.pathname = '/api/classify';
const apiResult = await doExtrasFetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Bypass-Tunnel-Reminder': 'bypass',
},
body: JSON.stringify({ text: text }),
});
if (apiResult.ok) {
const data = await apiResult.json();
return data.classification[0].label;
}
}
} catch (error) {
console.log(error);
return FALLBACK_EXPRESSION;
} }
} }
@ -821,7 +896,7 @@ async function getSpritesList(name) {
async function getExpressionsList() { async function getExpressionsList() {
// get something for offline mode (default images) // get something for offline mode (default images)
if (!modules.includes('classify')) { if (!modules.includes('classify') && !extension_settings.expressions.local) {
return DEFAULT_EXPRESSIONS; return DEFAULT_EXPRESSIONS;
} }
@ -829,20 +904,34 @@ async function getExpressionsList() {
return expressionsList; return expressionsList;
} }
const url = new URL(getApiUrl());
url.pathname = '/api/classify/labels';
try { try {
const apiResult = await doExtrasFetch(url, { if (extension_settings.expressions.local) {
method: 'GET', const apiResult = await fetch('/api/extra/classify/labels', {
headers: { 'Bypass-Tunnel-Reminder': 'bypass' }, method: 'POST',
}); headers: getRequestHeaders(),
});
if (apiResult.ok) { if (apiResult.ok) {
const data = await apiResult.json();
expressionsList = data.labels;
return expressionsList;
}
} else {
const url = new URL(getApiUrl());
url.pathname = '/api/classify/labels';
const data = await apiResult.json(); const apiResult = await doExtrasFetch(url, {
expressionsList = data.labels; method: 'GET',
return expressionsList; headers: { 'Bypass-Tunnel-Reminder': 'bypass' },
});
if (apiResult.ok) {
const data = await apiResult.json();
expressionsList = data.labels;
return expressionsList;
}
} }
} }
catch (error) { catch (error) {
@ -852,7 +941,7 @@ async function getExpressionsList() {
} }
async function setExpression(character, expression, force) { async function setExpression(character, expression, force) {
if (!extension_settings.expressions.talkinghead) { if (extension_settings.expressions.local || !extension_settings.expressions.talkinghead) {
console.debug('entered setExpressions'); console.debug('entered setExpressions');
await validateImages(character); await validateImages(character);
const img = $('img.expression'); const img = $('img.expression');
@ -1226,6 +1315,11 @@ function setExpressionOverrideHtml(forceClear = false) {
$('#expressions_show_default').on('input', onExpressionsShowDefaultInput); $('#expressions_show_default').on('input', onExpressionsShowDefaultInput);
$('#expression_upload_pack_button').on('click', onClickExpressionUploadPackButton); $('#expression_upload_pack_button').on('click', onClickExpressionUploadPackButton);
$('#expressions_show_default').prop('checked', extension_settings.expressions.showDefault).trigger('input'); $('#expressions_show_default').prop('checked', extension_settings.expressions.showDefault).trigger('input');
$('#expression_local').prop('checked', extension_settings.expressions.local).on('input', function () {
extension_settings.expressions.local = !!$(this).prop('checked');
moduleWorker();
saveSettingsDebounced();
});
$('#expression_override_cleanup_button').on('click', onClickExpressionOverrideRemoveAllButton); $('#expression_override_cleanup_button').on('click', onClickExpressionOverrideRemoveAllButton);
$(document).on('dragstart', '.expression', (e) => { $(document).on('dragstart', '.expression', (e) => {
e.preventDefault() e.preventDefault()

View File

@ -6,14 +6,14 @@
</div> </div>
<div class="inline-drawer-content"> <div class="inline-drawer-content">
<!-- Toggle button for aituber/static images --> <label class="checkbox_label" for="expression_local" title="Use classification model without the Extras server.">
<div class="toggle_button"> <input id="expression_local" type="checkbox" />
<label class="switch"> <span data-i18n="Local server classification">Local server classification</span>
<input id="image_type_toggle" type="checkbox"> </label>
<span class="slider round"></span> <label id="image_type_block" class="checkbox_label" for="image_type_toggle">
<label for="image_type_toggle">Image Type - talkinghead (extras)</label> <input id="image_type_toggle" type="checkbox">
</label> <span>Image Type - talkinghead (extras)</span>
</div> </label>
<div class="offline_mode"> <div class="offline_mode">
<small>You are in offline mode. Click on the image below to set the expression.</small> <small>You are in offline mode. Click on the image below to set the expression.</small>
</div> </div>

View File

@ -739,7 +739,7 @@ window.chromadb_interceptGeneration = async (chat, maxContext) => {
// No memories? No prompt. // No memories? No prompt.
const promptBlob = (tokenApprox == 0) ? "" : wrapperMsg.replace('{{memories}}', allMemoryBlob); const promptBlob = (tokenApprox == 0) ? "" : wrapperMsg.replace('{{memories}}', allMemoryBlob);
console.debug("CHROMADB: prompt blob: %o", promptBlob); console.debug("CHROMADB: prompt blob: %o", promptBlob);
context.setExtensionPrompt(MODULE_NAME, promptBlob, extension_prompt_types.AFTER_SCENARIO); context.setExtensionPrompt(MODULE_NAME, promptBlob, extension_prompt_types.IN_PROMPT);
} }
if (selectedStrategy === 'custom') { if (selectedStrategy === 'custom') {
const context = getContext(); const context = getContext();

View File

@ -63,7 +63,7 @@ const defaultSettings = {
source: summary_sources.extras, source: summary_sources.extras,
prompt: defaultPrompt, prompt: defaultPrompt,
template: defaultTemplate, template: defaultTemplate,
position: extension_prompt_types.AFTER_SCENARIO, position: extension_prompt_types.IN_PROMPT,
depth: 2, depth: 2,
promptWords: 200, promptWords: 200,
promptMinWords: 25, promptMinWords: 25,

View File

@ -135,8 +135,8 @@ const languageCodes = {
'Zulu': 'zu', 'Zulu': 'zu',
}; };
const KEY_REQUIRED = ['deepl','libre']; const KEY_REQUIRED = ['deepl', 'libre'];
const LOCAL_URL = ['libre']; const LOCAL_URL = ['libre', 'oneringtranslator', 'deeplx'];
function showKeysButton() { function showKeysButton() {
const providerRequiresKey = KEY_REQUIRED.includes(extension_settings.translate.provider); const providerRequiresKey = KEY_REQUIRED.includes(extension_settings.translate.provider);
@ -144,6 +144,7 @@ function showKeysButton() {
$("#translate_key_button").toggle(providerRequiresKey); $("#translate_key_button").toggle(providerRequiresKey);
$("#translate_key_button").toggleClass('success', Boolean(secret_state[extension_settings.translate.provider])); $("#translate_key_button").toggleClass('success', Boolean(secret_state[extension_settings.translate.provider]));
$("#translate_url_button").toggle(providerOptionalUrl); $("#translate_url_button").toggle(providerOptionalUrl);
$("#translate_url_button").toggleClass('success', Boolean(secret_state[extension_settings.translate.provider + "_url"]));
} }
function loadSettings() { function loadSettings() {
@ -184,8 +185,33 @@ async function translateIncomingMessage(messageId) {
updateMessageBlock(messageId, message); updateMessageBlock(messageId, message);
} }
async function translateProviderOneRing(text, lang) {
let from_lang = lang == extension_settings.translate.internal_language
? extension_settings.translate.target_language
: extension_settings.translate.internal_language;
const response = await fetch('/api/translate/onering', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({ text: text, from_lang: from_lang, to_lang: lang }),
});
if (response.ok) {
const result = await response.text();
return result;
}
throw new Error(response.statusText);
}
/**
* Translates text using the LibreTranslate API
* @param {string} text Text to translate
* @param {string} lang Target language code
* @returns {Promise<string>} Translated text
*/
async function translateProviderLibre(text, lang) { async function translateProviderLibre(text, lang) {
const response = await fetch('/libre_translate', { const response = await fetch('/api/translate/libre', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({ text: text, lang: lang }), body: JSON.stringify({ text: text, lang: lang }),
@ -199,8 +225,14 @@ async function translateProviderLibre(text, lang) {
throw new Error(response.statusText); throw new Error(response.statusText);
} }
/**
* Translates text using the Google Translate API
* @param {string} text Text to translate
* @param {string} lang Target language code
* @returns {Promise<string>} Translated text
*/
async function translateProviderGoogle(text, lang) { async function translateProviderGoogle(text, lang) {
const response = await fetch('/google_translate', { const response = await fetch('/api/translate/google', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({ text: text, lang: lang }), body: JSON.stringify({ text: text, lang: lang }),
@ -214,12 +246,18 @@ async function translateProviderGoogle(text, lang) {
throw new Error(response.statusText); throw new Error(response.statusText);
} }
/**
* Translates text using the DeepL API
* @param {string} text Text to translate
* @param {string} lang Target language code
* @returns {Promise<string>} Translated text
*/
async function translateProviderDeepl(text, lang) { async function translateProviderDeepl(text, lang) {
if (!secret_state.deepl) { if (!secret_state.deepl) {
throw new Error('No DeepL API key'); throw new Error('No DeepL API key');
} }
const response = await fetch('/deepl_translate', { const response = await fetch('/api/translate/deepl', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({ text: text, lang: lang }), body: JSON.stringify({ text: text, lang: lang }),
@ -233,6 +271,33 @@ async function translateProviderDeepl(text, lang) {
throw new Error(response.statusText); throw new Error(response.statusText);
} }
/**
* Translates text using the DeepLX API
* @param {string} text Text to translate
* @param {string} lang Target language code
* @returns {Promise<string>} Translated text
*/
async function translateProviderDeepLX(text, lang) {
const response = await fetch('/api/translate/deeplx', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({ text: text, lang: lang }),
});
if (response.ok) {
const result = await response.text();
return result;
}
throw new Error(response.statusText);
}
/**
* Translates text using the selected translation provider
* @param {string} text Text to translate
* @param {string} lang Target language code
* @returns {Promise<string>} Translated text
*/
async function translate(text, lang) { async function translate(text, lang) {
try { try {
if (text == '') { if (text == '') {
@ -246,6 +311,10 @@ async function translate(text, lang) {
return await translateProviderGoogle(text, lang); return await translateProviderGoogle(text, lang);
case 'deepl': case 'deepl':
return await translateProviderDeepl(text, lang); return await translateProviderDeepl(text, lang);
case 'deeplx':
return await translateProviderDeepLX(text, lang);
case 'oneringtranslator':
return await translateProviderOneRing(text, lang);
default: default:
console.error('Unknown translation provider', extension_settings.translate.provider); console.error('Unknown translation provider', extension_settings.translate.provider);
return text; return text;
@ -391,6 +460,8 @@ jQuery(() => {
<option value="libre">Libre</option> <option value="libre">Libre</option>
<option value="google">Google</option> <option value="google">Google</option>
<option value="deepl">DeepL</option> <option value="deepl">DeepL</option>
<option value="deeplx">DeepLX</option>
<option value="oneringtranslator">OneRingTranslator</option>
<select> <select>
<div id="translate_key_button" class="menu_button fa-solid fa-key margin0"></div> <div id="translate_key_button" class="menu_button fa-solid fa-key margin0"></div>
<div id="translate_url_button" class="menu_button fa-solid fa-link margin0"></div> <div id="translate_url_button" class="menu_button fa-solid fa-link margin0"></div>
@ -443,12 +514,17 @@ jQuery(() => {
await writeSecret(extension_settings.translate.provider, key); await writeSecret(extension_settings.translate.provider, key);
toastr.success('API Key saved'); toastr.success('API Key saved');
$("#translate_key_button").addClass('success');
}); });
$('#translate_url_button').on('click', async () => { $('#translate_url_button').on('click', async () => {
const optionText = $('#translation_provider option:selected').text(); const optionText = $('#translation_provider option:selected').text();
const exampleURLs = {}; const exampleURLs = {
exampleURLs['libre'] = 'http://127.0.0.1:5000/translate'; 'libre': 'http://127.0.0.1:5000/translate',
const url = await callPopup(`<h3>${optionText} API URL</h3><i>Example: <tt>` + exampleURLs[extension_settings.translate.provider] + `</tt></i>`, 'input'); 'oneringtranslator': 'http://127.0.0.1:4990/translate',
'deeplx': 'http://127.0.0.1:1188/translate',
};
const popupText = `<h3>${optionText} API URL</h3><i>Example: <tt>${String(exampleURLs[extension_settings.translate.provider])}</tt></i>`;
const url = await callPopup(popupText, 'input');
if (url == false) { if (url == false) {
return; return;
@ -456,6 +532,7 @@ jQuery(() => {
await writeSecret(extension_settings.translate.provider + "_url", url); await writeSecret(extension_settings.translate.provider + "_url", url);
toastr.success('API URL saved'); toastr.success('API URL saved');
$("#translate_url_button").addClass('success');
}); });
loadSettings(); loadSettings();

View File

@ -0,0 +1,434 @@
import { eventSource, event_types, extension_prompt_types, getCurrentChatId, getRequestHeaders, is_send_press, saveSettingsDebounced, setExtensionPrompt, substituteParams } from "../../../script.js";
import { ModuleWorkerWrapper, extension_settings, getContext, renderExtensionTemplate } from "../../extensions.js";
import { collapseNewlines, power_user, ui_mode } from "../../power-user.js";
import { debounce, getStringHash as calculateHash, waitUntilCondition } from "../../utils.js";
const MODULE_NAME = 'vectors';
export const EXTENSION_PROMPT_TAG = '3_vectors';
const settings = {
enabled: false,
source: 'local',
template: `Past events: {{text}}`,
depth: 2,
position: extension_prompt_types.IN_PROMPT,
protect: 5,
insert: 3,
query: 2,
};
const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
async function onVectorizeAllClick() {
try {
if (!settings.enabled) {
return;
}
const chatId = getCurrentChatId();
if (!chatId) {
toastr.info('No chat selected', 'Vectorization aborted');
return;
}
const batchSize = 5;
const elapsedLog = [];
let finished = false;
$('#vectorize_progress').show();
$('#vectorize_progress_percent').text('0');
$('#vectorize_progress_eta').text('...');
while (!finished) {
if (is_send_press) {
toastr.info('Message generation is in progress.', 'Vectorization aborted');
throw new Error('Message generation is in progress.');
}
const startTime = Date.now();
const remaining = await synchronizeChat(batchSize);
const elapsed = Date.now() - startTime;
elapsedLog.push(elapsed);
finished = remaining <= 0;
const total = getContext().chat.length;
const processed = total - remaining;
const processedPercent = Math.round((processed / total) * 100); // percentage of the work done
const lastElapsed = elapsedLog.slice(-5); // last 5 elapsed times
const averageElapsed = lastElapsed.reduce((a, b) => a + b, 0) / lastElapsed.length; // average time needed to process one item
const pace = averageElapsed / batchSize; // time needed to process one item
const remainingTime = Math.round(pace * remaining / 1000);
$('#vectorize_progress_percent').text(processedPercent);
$('#vectorize_progress_eta').text(remainingTime);
if (chatId !== getCurrentChatId()) {
throw new Error('Chat changed');
}
}
} catch (error) {
console.error('Vectors: Failed to vectorize all', error);
} finally {
$('#vectorize_progress').hide();
}
}
let syncBlocked = false;
async function synchronizeChat(batchSize = 5) {
if (!settings.enabled) {
return -1;
}
try {
await waitUntilCondition(() => !syncBlocked && !is_send_press, 1000);
} catch {
console.log('Vectors: Synchronization blocked by another process');
return -1;
}
try {
syncBlocked = true;
const context = getContext();
const chatId = getCurrentChatId();
if (!chatId || !Array.isArray(context.chat)) {
console.debug('Vectors: No chat selected');
return -1;
}
const hashedMessages = context.chat.filter(x => !x.is_system).map(x => ({ text: String(x.mes), hash: getStringHash(x.mes) }));
const hashesInCollection = await getSavedHashes(chatId);
const newVectorItems = hashedMessages.filter(x => !hashesInCollection.includes(x.hash));
const deletedHashes = hashesInCollection.filter(x => !hashedMessages.some(y => y.hash === x));
if (newVectorItems.length > 0) {
console.log(`Vectors: Found ${newVectorItems.length} new items. Processing ${batchSize}...`);
await insertVectorItems(chatId, newVectorItems.slice(0, batchSize));
}
if (deletedHashes.length > 0) {
await deleteVectorItems(chatId, deletedHashes);
console.log(`Vectors: Deleted ${deletedHashes.length} old hashes`);
}
return newVectorItems.length - batchSize;
} catch (error) {
toastr.error('Check server console for more details', 'Vectorization failed');
console.error('Vectors: Failed to synchronize chat', error);
return -1;
} finally {
syncBlocked = false;
}
}
// Cache object for storing hash values
const hashCache = {};
/**
* Gets the hash value for a given string
* @param {string} str Input string
* @returns {number} Hash value
*/
function getStringHash(str) {
// Check if the hash is already in the cache
if (hashCache.hasOwnProperty(str)) {
return hashCache[str];
}
// Calculate the hash value
const hash = calculateHash(str);
// Store the hash in the cache
hashCache[str] = hash;
return hash;
}
/**
* Removes the most relevant messages from the chat and displays them in the extension prompt
* @param {object[]} chat Array of chat messages
*/
async function rearrangeChat(chat) {
try {
// Clear the extension prompt
setExtensionPrompt(EXTENSION_PROMPT_TAG, '', extension_prompt_types.IN_PROMPT, 0);
if (!settings.enabled) {
return;
}
const chatId = getCurrentChatId();
if (!chatId || !Array.isArray(chat)) {
console.debug('Vectors: No chat selected');
return;
}
if (chat.length < settings.protect) {
console.debug(`Vectors: Not enough messages to rearrange (less than ${settings.protect})`);
return;
}
const queryText = getQueryText(chat);
if (queryText.length === 0) {
console.debug('Vectors: No text to query');
return;
}
// Get the most relevant messages, excluding the last few
const queryHashes = await queryCollection(chatId, queryText, settings.insert);
const queriedMessages = [];
const retainMessages = chat.slice(-settings.protect);
for (const message of chat) {
if (retainMessages.includes(message)) {
continue;
}
if (message.mes && queryHashes.includes(getStringHash(message.mes))) {
queriedMessages.push(message);
}
}
// Rearrange queried messages to match query order
// Order is reversed because more relevant are at the lower indices
queriedMessages.sort((a, b) => queryHashes.indexOf(getStringHash(b.mes)) - queryHashes.indexOf(getStringHash(a.mes)));
// Remove queried messages from the original chat array
for (const message of chat) {
if (queriedMessages.includes(message)) {
chat.splice(chat.indexOf(message), 1);
}
}
if (queriedMessages.length === 0) {
console.debug('Vectors: No relevant messages found');
return;
}
// Format queried messages into a single string
const insertedText = getPromptText(queriedMessages);
setExtensionPrompt(EXTENSION_PROMPT_TAG, insertedText, settings.position, settings.depth);
} catch (error) {
console.error('Vectors: Failed to rearrange chat', error);
}
}
/**
* @param {any[]} queriedMessages
* @returns {string}
*/
function getPromptText(queriedMessages) {
const queriedText = queriedMessages.map(x => collapseNewlines(`${x.name}: ${x.mes}`).trim()).join('\n\n');
console.log('Vectors: relevant past messages found.\n', queriedText);
return substituteParams(settings.template.replace(/{{text}}/i, queriedText));
}
window['vectors_rearrangeChat'] = rearrangeChat;
const onChatEvent = debounce(async () => await moduleWorker.update(), 500);
/**
* Gets the text to query from the chat
* @param {object[]} chat Chat messages
* @returns {string} Text to query
*/
function getQueryText(chat) {
let queryText = '';
let i = 0;
for (const message of chat.slice().reverse()) {
if (message.mes) {
queryText += message.mes + '\n';
i++;
}
if (i === settings.query) {
break;
}
}
return collapseNewlines(queryText).trim();
}
/**
* Gets the saved hashes for a collection
* @param {string} collectionId
* @returns {Promise<number[]>} Saved hashes
*/
async function getSavedHashes(collectionId) {
const response = await fetch('/api/vector/list', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
collectionId: collectionId,
source: settings.source,
}),
});
if (!response.ok) {
throw new Error(`Failed to get saved hashes for collection ${collectionId}`);
}
const hashes = await response.json();
return hashes;
}
/**
* Inserts vector items into a collection
* @param {string} collectionId - The collection to insert into
* @param {{ hash: number, text: string }[]} items - The items to insert
* @returns {Promise<void>}
*/
async function insertVectorItems(collectionId, items) {
const response = await fetch('/api/vector/insert', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
collectionId: collectionId,
items: items,
source: settings.source,
}),
});
if (!response.ok) {
throw new Error(`Failed to insert vector items for collection ${collectionId}`);
}
}
/**
* Deletes vector items from a collection
* @param {string} collectionId - The collection to delete from
* @param {number[]} hashes - The hashes of the items to delete
* @returns {Promise<void>}
*/
async function deleteVectorItems(collectionId, hashes) {
const response = await fetch('/api/vector/delete', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
collectionId: collectionId,
hashes: hashes,
source: settings.source,
}),
});
if (!response.ok) {
throw new Error(`Failed to delete vector items for collection ${collectionId}`);
}
}
/**
* @param {string} collectionId - The collection to query
* @param {string} searchText - The text to query
* @param {number} topK - The number of results to return
* @returns {Promise<number[]>} - Hashes of the results
*/
async function queryCollection(collectionId, searchText, topK) {
const response = await fetch('/api/vector/query', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
collectionId: collectionId,
searchText: searchText,
topK: topK,
source: settings.source,
}),
});
if (!response.ok) {
throw new Error(`Failed to query collection ${collectionId}`);
}
const results = await response.json();
return results;
}
async function purgeVectorIndex(collectionId) {
try {
if (!settings.enabled) {
return;
}
const response = await fetch('/api/vector/purge', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
collectionId: collectionId,
}),
});
if (!response.ok) {
throw new Error(`Could not delete vector index for collection ${collectionId}`);
}
console.log(`Vectors: Purged vector index for collection ${collectionId}`);
} catch (error) {
console.error('Vectors: Failed to purge', error);
}
}
jQuery(async () => {
if (!extension_settings.vectors) {
extension_settings.vectors = settings;
}
Object.assign(settings, extension_settings.vectors);
$('#extensions_settings2').append(renderExtensionTemplate(MODULE_NAME, 'settings'));
$('#vectors_enabled').prop('checked', settings.enabled).on('input', () => {
settings.enabled = $('#vectors_enabled').prop('checked');
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_source').val(settings.source).on('change', () => {
settings.source = String($('#vectors_source').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_template').val(settings.template).on('input', () => {
settings.template = String($('#vectors_template').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_depth').val(settings.depth).on('input', () => {
settings.depth = Number($('#vectors_depth').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_protect').val(settings.protect).on('input', () => {
settings.protect = Number($('#vectors_protect').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_insert').val(settings.insert).on('input', () => {
settings.insert = Number($('#vectors_insert').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_query').val(settings.query).on('input', () => {
settings.query = Number($('#vectors_query').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$(`input[name="vectors_position"][value="${settings.position}"]`).prop('checked', true);
$('input[name="vectors_position"]').on('change', () => {
settings.position = Number($('input[name="vectors_position"]:checked').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_advanced_settings').toggleClass('displayNone', power_user.ui_mode === ui_mode.SIMPLE);
$('#vectors_vectorize_all').on('click', onVectorizeAllClick);
eventSource.on(event_types.MESSAGE_DELETED, onChatEvent);
eventSource.on(event_types.MESSAGE_EDITED, onChatEvent);
eventSource.on(event_types.MESSAGE_SENT, onChatEvent);
eventSource.on(event_types.MESSAGE_RECEIVED, onChatEvent);
eventSource.on(event_types.MESSAGE_SWIPED, onChatEvent);
eventSource.on(event_types.CHAT_DELETED, purgeVectorIndex);
eventSource.on(event_types.GROUP_CHAT_DELETED, purgeVectorIndex);
});

View File

@ -0,0 +1,12 @@
{
"display_name": "Vector Storage",
"loading_order": 100,
"requires": [],
"optional": [],
"generate_interceptor": "vectors_rearrangeChat",
"js": "index.js",
"css": "",
"author": "Cohee#1207",
"version": "1.0.0",
"homePage": "https://github.com/SillyTavern/SillyTavern"
}

View File

@ -0,0 +1,71 @@
<div class="vectors_settings">
<div class="inline-drawer">
<div class="inline-drawer-toggle inline-drawer-header">
<b>Vector Storage</b>
<div class="inline-drawer-icon fa-solid fa-circle-chevron-down down"></div>
</div>
<div class="inline-drawer-content">
<label class="checkbox_label" for="vectors_enabled">
<input id="vectors_enabled" type="checkbox" class="checkbox">
Enabled
</label>
<label for="vectors_source">
Vectorization Source
</label>
<select id="vectors_source" class="select">
<option value="local">Local</option>
<option value="openai">OpenAI</option>
</select>
<div id="vectors_advanced_settings" data-newbie-hidden>
<label for="vectors_template">
Insertion template:
</label>
<textarea id="vectors_template" class="text_pole textarea_compact autoSetHeight" rows="2" placeholder="Use {{text}} macro to specify the position of retrieved text."></textarea>
<label for="vectors_position">Injection position:</label>
<div class="radio_group">
<label>
<input type="radio" name="vectors_position" value="0" />
After Main Prompt / Story String
</label>
<label>
<input type="radio" name="vectors_position" value="1" />
In-chat @ Depth <input id="vectors_depth" class="text_pole widthUnset" type="number" min="0" max="99" />
</label>
</div>
<div class="flex-container">
<div class="flex1" title="Prevents last N messages from being placed out of order.">
<label for="vectors_protect">
<small>Retain#</small>
</label>
<input type="number" id="vectors_protect" class="text_pole widthUnset" min="1" max="99" />
</div>
<div class="flex1" title="How many last messages will be matched for relevance.">
<label for="vectors_query">
<small>Query#</small>
</label>
<input type="number" id="vectors_query" class="text_pole widthUnset" min="1" max="99" />
</div>
<div class="flex1" title="How many past messages to insert as memories.">
<label for="vectors_insert">
<small>Insert#</small>
</label>
<input type="number" id="vectors_insert" class="text_pole widthUnset" min="1" max="99" />
</div>
</div>
</div>
<small>
Old messages are vectorized gradually as you chat.
To process all previous messages, click the button below.
</small>
<div id="vectors_vectorize_all" class="menu_button menu_button_icon">
Vectorize All
</div>
<div id="vectorize_progress" style="display: none;">
<small>
Processed <span id="vectorize_progress_percent">0</span>% of messages.
ETA: <span id="vectorize_progress_eta">...</span> seconds.
</small>
</div>
</div>
</div>
</div>

View File

@ -166,7 +166,7 @@ export async function getGroupChat(groupId) {
for (let key of data) { for (let key of data) {
chat.push(key); chat.push(key);
} }
printMessages(); await printMessages();
} else { } else {
sendSystemMessage(system_message_types.GROUP, '', { isSmallSys: true }); sendSystemMessage(system_message_types.GROUP, '', { isSmallSys: true });
if (group && Array.isArray(group.members)) { if (group && Array.isArray(group.members)) {
@ -816,18 +816,26 @@ function activateNaturalOrder(members, input, lastMessage, allowSelfResponses, i
} }
async function deleteGroup(id) { async function deleteGroup(id) {
const group = groups.find((x) => x.id === id);
const response = await fetch("/deletegroup", { const response = await fetch("/deletegroup", {
method: "POST", method: "POST",
headers: getRequestHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({ id: id }), body: JSON.stringify({ id: id }),
}); });
if (group && Array.isArray(group.chats)) {
for (const chatId of group.chats) {
await eventSource.emit(event_types.GROUP_CHAT_DELETED, chatId);
}
}
if (response.ok) { if (response.ok) {
selected_group = null; selected_group = null;
delete tag_map[id]; delete tag_map[id];
resetChatState(); resetChatState();
clearChat(); clearChat();
printMessages(); await printMessages();
await getCharacters(); await getCharacters();
select_rm_info("group_delete", id); select_rm_info("group_delete", id);
@ -1493,6 +1501,8 @@ export async function deleteGroupChat(groupId, chatId) {
} else { } else {
await createNewGroupChat(groupId); await createNewGroupChat(groupId);
} }
await eventSource.emit(event_types.GROUP_CHAT_DELETED, chatId);
} }
} }

View File

@ -630,6 +630,12 @@ function populateChatCompletion(prompts, chatCompletion, { bias, quietPrompt, ty
if (true === afterScenario) chatCompletion.insert(authorsNote, 'scenario'); if (true === afterScenario) chatCompletion.insert(authorsNote, 'scenario');
} }
// Vectors Memory
if (prompts.has('vectorsMemory')) {
const vectorsMemory = Message.fromPrompt(prompts.get('vectorsMemory'));
chatCompletion.insert(vectorsMemory, 'main');
}
// Decide whether dialogue examples should always be added // Decide whether dialogue examples should always be added
if (power_user.pin_examples) { if (power_user.pin_examples) {
populateDialogueExamples(prompts, chatCompletion); populateDialogueExamples(prompts, chatCompletion);
@ -697,6 +703,14 @@ function preparePromptsForChatCompletion({Scenario, charPersonality, name2, worl
identifier: 'authorsNote' identifier: 'authorsNote'
}); });
// Vectors Memory
const vectorsMemory = extensionPrompts['3_vectors'];
if (vectorsMemory && vectorsMemory.value) systemPrompts.push({
role: 'system',
content: vectorsMemory.value,
identifier: 'vectorsMemory',
});
// Persona Description // Persona Description
if (power_user.persona_description && power_user.persona_description_position === persona_description_positions.IN_PROMPT) { if (power_user.persona_description && power_user.persona_description_position === persona_description_positions.IN_PROMPT) {
systemPrompts.push({ role: 'system', content: power_user.persona_description, identifier: 'personaDescription' }); systemPrompts.push({ role: 'system', content: power_user.persona_description, identifier: 'personaDescription' });

View File

@ -438,9 +438,9 @@ export function sortByCssOrder(a, b) {
* @param {boolean} include_newline Whether to include a newline character in the trimmed string. * @param {boolean} include_newline Whether to include a newline character in the trimmed string.
* @returns {string} The trimmed string. * @returns {string} The trimmed string.
* @example * @example
* end_trim_to_sentence('Hello, world! I am from'); // 'Hello, world!' * trimToEndSentence('Hello, world! I am from'); // 'Hello, world!'
*/ */
export function end_trim_to_sentence(input, include_newline = false) { export function trimToEndSentence(input, include_newline = false) {
const punctuation = new Set(['.', '!', '?', '*', '"', ')', '}', '`', ']', '$', '。', '', '', '”', '', '】', '】', '', '」', '】']); // extend this as you see fit const punctuation = new Set(['.', '!', '?', '*', '"', ')', '}', '`', ']', '$', '。', '', '', '”', '', '】', '】', '', '」', '】']); // extend this as you see fit
let last = -1; let last = -1;
@ -465,6 +465,26 @@ export function end_trim_to_sentence(input, include_newline = false) {
return input.substring(0, last + 1).trimEnd(); return input.substring(0, last + 1).trimEnd();
} }
export function trimToStartSentence(input) {
let p1 = input.indexOf(".");
let p2 = input.indexOf("!");
let p3 = input.indexOf("?");
let p4 = input.indexOf("\n");
let first = p1;
let skip1 = false;
if (p2 > 0 && p2 < first) { first = p2; }
if (p3 > 0 && p3 < first) { first = p3; }
if (p4 > 0 && p4 < first) { first = p4; skip1 = true; }
if (first > 0) {
if (skip1) {
return input.substring(first + 1);
} else {
return input.substring(first + 2);
}
}
return input;
}
/** /**
* Counts the number of occurrences of a character in a string. * Counts the number of occurrences of a character in a string.
* @param {string} string The string to count occurrences in. * @param {string} string The string to count occurrences in.

View File

@ -1209,6 +1209,7 @@ input[type="file"] {
.radio_group { .radio_group {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
margin-top: 5px;
} }
#extension_floating_counter { #extension_floating_counter {
@ -3580,6 +3581,15 @@ a {
align-self: center; align-self: center;
} }
#show_more_messages {
text-align: center;
margin: 10px 0;
font-weight: 500;
text-decoration: underline;
order: -1;
cursor: pointer;
}
#select_chat_search { #select_chat_search {
background-color: transparent; background-color: transparent;
border: none; border: none;

249
server.js
View File

@ -58,7 +58,6 @@ const { Tokenizer } = require('@agnai/web-tokenizers');
// misc/other imports // misc/other imports
const _ = require('lodash'); const _ = require('lodash');
const { generateRequestUrl, normaliseResponse } = require('google-translate-api-browser');
// Unrestrict console logs display limit // Unrestrict console logs display limit
util.inspect.defaultOptions.maxArrayLength = null; util.inspect.defaultOptions.maxArrayLength = null;
@ -74,6 +73,7 @@ const characterCardParser = require('./src/character-card-parser.js');
const contentManager = require('./src/content-manager'); const contentManager = require('./src/content-manager');
const novelai = require('./src/novelai'); const novelai = require('./src/novelai');
const statsHelpers = require('./statsHelpers.js'); const statsHelpers = require('./statsHelpers.js');
const { writeSecret, readSecret, readSecretState, migrateSecrets, SECRET_KEYS, getAllSecrets } = require('./src/secrets');
function createDefaultFiles() { function createDefaultFiles() {
const files = { const files = {
@ -338,6 +338,7 @@ function humanizedISO8601DateTime(date) {
var charactersPath = 'public/characters/'; var charactersPath = 'public/characters/';
var chatsPath = 'public/chats/'; var chatsPath = 'public/chats/';
const UPLOADS_PATH = './uploads'; const UPLOADS_PATH = './uploads';
const SETTINGS_FILE = './public/settings.json';
const AVATAR_WIDTH = 400; const AVATAR_WIDTH = 400;
const AVATAR_HEIGHT = 600; const AVATAR_HEIGHT = 600;
const jsonParser = express.json({ limit: '100mb' }); const jsonParser = express.json({ limit: '100mb' });
@ -4112,7 +4113,7 @@ const setupTasks = async function () {
console.log(`SillyTavern ${version.pkgVersion}` + (version.gitBranch ? ` '${version.gitBranch}' (${version.gitRevision})` : '')); console.log(`SillyTavern ${version.pkgVersion}` + (version.gitBranch ? ` '${version.gitBranch}' (${version.gitRevision})` : ''));
backupSettings(); backupSettings();
migrateSecrets(); migrateSecrets(SETTINGS_FILE);
ensurePublicDirectoriesExist(); ensurePublicDirectoriesExist();
await ensureThumbnailCache(); await ensureThumbnailCache();
contentManager.checkForNewContent(); contentManager.checkForNewContent();
@ -4263,69 +4264,6 @@ function ensurePublicDirectoriesExist() {
} }
} }
const SECRETS_FILE = './secrets.json';
const SETTINGS_FILE = './public/settings.json';
const SECRET_KEYS = {
HORDE: 'api_key_horde',
MANCER: 'api_key_mancer',
OPENAI: 'api_key_openai',
NOVEL: 'api_key_novel',
CLAUDE: 'api_key_claude',
DEEPL: 'deepl',
LIBRE: 'libre',
LIBRE_URL: 'libre_url',
OPENROUTER: 'api_key_openrouter',
SCALE: 'api_key_scale',
AI21: 'api_key_ai21',
SCALE_COOKIE: 'scale_cookie',
}
function migrateSecrets() {
if (!fs.existsSync(SETTINGS_FILE)) {
console.log('Settings file does not exist');
return;
}
try {
let modified = false;
const fileContents = fs.readFileSync(SETTINGS_FILE, 'utf8');
const settings = JSON.parse(fileContents);
const oaiKey = settings?.api_key_openai;
const hordeKey = settings?.horde_settings?.api_key;
const novelKey = settings?.api_key_novel;
if (typeof oaiKey === 'string') {
console.log('Migrating OpenAI key...');
writeSecret(SECRET_KEYS.OPENAI, oaiKey);
delete settings.api_key_openai;
modified = true;
}
if (typeof hordeKey === 'string') {
console.log('Migrating Horde key...');
writeSecret(SECRET_KEYS.HORDE, hordeKey);
delete settings.horde_settings.api_key;
modified = true;
}
if (typeof novelKey === 'string') {
console.log('Migrating Novel key...');
writeSecret(SECRET_KEYS.NOVEL, novelKey);
delete settings.api_key_novel;
modified = true;
}
if (modified) {
console.log('Writing updated settings.json...');
const settingsContent = JSON.stringify(settings);
writeFileAtomicSync(SETTINGS_FILE, settingsContent, "utf-8");
}
}
catch (error) {
console.error('Could not migrate secrets file. Proceed with caution.');
}
}
app.post('/writesecret', jsonParser, (request, response) => { app.post('/writesecret', jsonParser, (request, response) => {
const key = request.body.key; const key = request.body.key;
const value = request.body.value; const value = request.body.value;
@ -4335,19 +4273,9 @@ app.post('/writesecret', jsonParser, (request, response) => {
}); });
app.post('/readsecretstate', jsonParser, (_, response) => { app.post('/readsecretstate', jsonParser, (_, response) => {
if (!fs.existsSync(SECRETS_FILE)) {
return response.send({});
}
try { try {
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf8'); const state = readSecretState();
const secrets = JSON.parse(fileContents);
const state = {};
for (const key of Object.values(SECRET_KEYS)) {
state[key] = !!secrets[key]; // convert to boolean
}
return response.send(state); return response.send(state);
} catch (error) { } catch (error) {
console.error(error); console.error(error);
@ -4393,14 +4321,13 @@ app.post('/viewsecrets', jsonParser, async (_, response) => {
return response.sendStatus(403); return response.sendStatus(403);
} }
if (!fs.existsSync(SECRETS_FILE)) {
console.error('secrets.json does not exist');
return response.sendStatus(404);
}
try { try {
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8'); const secrets = getAllSecrets();
const secrets = JSON.parse(fileContents);
if (!secrets) {
return response.sendStatus(404);
}
return response.send(secrets); return response.send(secrets);
} catch (error) { } catch (error) {
console.error(error); console.error(error);
@ -4835,122 +4762,6 @@ app.post('/api/sd/generate', jsonParser, async (request, response) => {
} }
}); });
app.post('/libre_translate', jsonParser, async (request, response) => {
const key = readSecret(SECRET_KEYS.LIBRE);
const url = readSecret(SECRET_KEYS.LIBRE_URL);
const text = request.body.text;
const lang = request.body.lang;
if (!text || !lang) {
return response.sendStatus(400);
}
console.log('Input text: ' + text);
try {
const result = await fetch(url, {
method: "POST",
body: JSON.stringify({
q: text,
source: "auto",
target: lang,
format: "text",
api_key: key
}),
headers: { "Content-Type": "application/json" }
});
if (!result.ok) {
return response.sendStatus(result.status);
}
const json = await result.json();
console.log('Translated text: ' + json.translatedText);
return response.send(json.translatedText);
} catch (error) {
console.log("Translation error: " + error.message);
return response.sendStatus(500);
}
});
app.post('/google_translate', jsonParser, async (request, response) => {
const text = request.body.text;
const lang = request.body.lang;
if (!text || !lang) {
return response.sendStatus(400);
}
console.log('Input text: ' + text);
const url = generateRequestUrl(text, { to: lang });
https.get(url, (resp) => {
let data = '';
resp.on('data', (chunk) => {
data += chunk;
});
resp.on('end', () => {
const result = normaliseResponse(JSON.parse(data));
console.log('Translated text: ' + result.text);
return response.send(result.text);
});
}).on("error", (err) => {
console.log("Translation error: " + err.message);
return response.sendStatus(500);
});
});
app.post('/deepl_translate', jsonParser, async (request, response) => {
const key = readSecret(SECRET_KEYS.DEEPL);
if (!key) {
return response.sendStatus(401);
}
const text = request.body.text;
const lang = request.body.lang;
if (!text || !lang) {
return response.sendStatus(400);
}
console.log('Input text: ' + text);
const params = new URLSearchParams();
params.append('text', text);
params.append('target_lang', lang);
try {
const result = await fetch('https://api-free.deepl.com/v2/translate', {
method: 'POST',
body: params,
headers: {
'Accept': 'application/json',
'Authorization': `DeepL-Auth-Key ${key}`,
'Content-Type': 'application/x-www-form-urlencoded',
},
timeout: 0,
});
if (!result.ok) {
return response.sendStatus(result.status);
}
const json = await result.json();
console.log('Translated text: ' + json.translations[0].text);
return response.send(json.translations[0].text);
} catch (error) {
console.log("Translation error: " + error.message);
return response.sendStatus(500);
}
});
app.post('/novel_tts', jsonParser, async (request, response) => { app.post('/novel_tts', jsonParser, async (request, response) => {
const token = readSecret(SECRET_KEYS.NOVEL); const token = readSecret(SECRET_KEYS.NOVEL);
@ -5334,27 +5145,6 @@ function importRisuSprites(data) {
} }
} }
function writeSecret(key, value) {
if (!fs.existsSync(SECRETS_FILE)) {
const emptyFile = JSON.stringify({});
writeFileAtomicSync(SECRETS_FILE, emptyFile, "utf-8");
}
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
const secrets = JSON.parse(fileContents);
secrets[key] = value;
writeFileAtomicSync(SECRETS_FILE, JSON.stringify(secrets), "utf-8");
}
function readSecret(key) {
if (!fs.existsSync(SECRETS_FILE)) {
return undefined;
}
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
const secrets = JSON.parse(fileContents);
return secrets[key];
}
async function readAllChunks(readableStream) { async function readAllChunks(readableStream) {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
@ -5427,8 +5217,6 @@ async function getImageBuffers(zipFilePath) {
}); });
} }
/** /**
* This function extracts the extension information from the manifest file. * This function extracts the extension information from the manifest file.
* @param {string} extensionPath - The path of the extension folder * @param {string} extensionPath - The path of the extension folder
@ -5867,3 +5655,20 @@ app.post('/get_character_assets_list', jsonParser, async (request, response) =>
return response.sendStatus(500); return response.sendStatus(500);
} }
}); });
// Vector storage DB
require('./src/vectors').registerEndpoints(app, jsonParser);
// Chat translation
require('./src/translate').registerEndpoints(app, jsonParser);
// Emotion classification
import('./src/classify.mjs').then(module => {
module.default.registerEndpoints(app, jsonParser);
}).catch(err => {
console.error(err);
});
// Image captioning
import('./src/caption.mjs').then(module => {
module.default.registerEndpoints(app, jsonParser);
}).catch(err => {
console.error(err);
});

70
src/caption.mjs Normal file
View File

@ -0,0 +1,70 @@
import { pipeline, env, RawImage } from 'sillytavern-transformers';
import path from 'path';
import { getConfig } from './util.js';
// Limit the number of threads to 1 to avoid issues on Android
env.backends.onnx.wasm.numThreads = 1;
class PipelineAccessor {
/**
* @type {import("sillytavern-transformers").ImageToTextPipeline}
*/
pipe;
async get() {
if (!this.pipe) {
const cache_dir = path.join(process.cwd(), 'cache');
const model = this.getCaptioningModel();
this.pipe = await pipeline('image-to-text', model, { cache_dir, quantized: true });
}
return this.pipe;
}
getCaptioningModel() {
const DEFAULT_MODEL = 'Xenova/vit-gpt2-image-captioning';
try {
const config = getConfig();
const model = config?.extras?.captioningModel;
return model || DEFAULT_MODEL;
} catch (error) {
console.warn('Failed to read config.conf, using default captioning model.');
return DEFAULT_MODEL;
}
}
}
/**
* @param {import("express").Express} app
* @param {any} jsonParser
*/
function registerEndpoints(app, jsonParser) {
const pipelineAccessor = new PipelineAccessor();
app.post('/api/extra/caption', jsonParser, async (req, res) => {
try {
const { image } = req.body;
// base64 string to blob
const buffer = Buffer.from(image, 'base64');
const byteArray = new Uint8Array(buffer);
const blob = new Blob([byteArray]);
const rawImage = await RawImage.fromBlob(blob);
const pipe = await pipelineAccessor.get();
const result = await pipe(rawImage);
const text = result[0].generated_text;
console.log('Image caption:', text);
return res.json({ caption: text });
} catch (error) {
console.error(error);
return res.sendStatus(500);
}
});
}
export default {
registerEndpoints,
};

87
src/classify.mjs Normal file
View File

@ -0,0 +1,87 @@
import { pipeline, env } from 'sillytavern-transformers';
import path from 'path';
import { getConfig } from './util.js';
// Limit the number of threads to 1 to avoid issues on Android
env.backends.onnx.wasm.numThreads = 1;
class PipelineAccessor {
/**
* @type {import("sillytavern-transformers").TextClassificationPipeline}
*/
pipe;
async get() {
if (!this.pipe) {
const cache_dir = path.join(process.cwd(), 'cache');
const model = this.getClassificationModel();
this.pipe = await pipeline('text-classification', model, { cache_dir, quantized: true });
}
return this.pipe;
}
getClassificationModel() {
const DEFAULT_MODEL = 'Cohee/distilbert-base-uncased-go-emotions-onnx';
try {
const config = getConfig();
const model = config?.extras?.classificationModel;
return model || DEFAULT_MODEL;
} catch (error) {
console.warn('Failed to read config.conf, using default classification model.');
return DEFAULT_MODEL;
}
}
}
/**
* @param {import("express").Express} app
* @param {any} jsonParser
*/
function registerEndpoints(app, jsonParser) {
const cacheObject = {};
const pipelineAccessor = new PipelineAccessor();
app.post('/api/extra/classify/labels', jsonParser, async (req, res) => {
try {
const pipe = await pipelineAccessor.get();
const result = Object.keys(pipe.model.config.label2id);
return res.json({ labels: result });
} catch (error) {
console.error(error);
return res.sendStatus(500);
}
});
app.post('/api/extra/classify', jsonParser, async (req, res) => {
try {
const { text } = req.body;
async function getResult(text) {
if (cacheObject.hasOwnProperty(text)) {
return cacheObject[text];
} else {
const pipe = await pipelineAccessor.get();
const result = await pipe(text, { topk: 5 });
result.sort((a, b) => b.score - a.score);
cacheObject[text] = result;
return result;
}
}
console.log('Classify input:', text);
const result = await getResult(text);
console.log('Classify output:', result);
return res.json({ classification: result });
} catch (error) {
console.error(error);
return res.sendStatus(500);
}
});
}
export default {
registerEndpoints,
};

38
src/local-vectors.js Normal file
View File

@ -0,0 +1,38 @@
require('@tensorflow/tfjs');
const encoder = require('@tensorflow-models/universal-sentence-encoder');
/**
* Lazy loading class for the embedding model.
*/
class EmbeddingModel {
/**
* @type {encoder.UniversalSentenceEncoder} - The embedding model
*/
model;
async get() {
if (!this.model) {
this.model = await encoder.load();
}
return this.model;
}
}
const model = new EmbeddingModel();
/**
* @param {string} text
*/
async function getLocalVector(text) {
const use = await model.get();
const tensor = await use.embed(text);
const vector = Array.from(await tensor.data());
return vector;
}
module.exports = {
getLocalVector,
};

View File

@ -2,12 +2,7 @@
* When applied, this middleware will ensure the request contains the required header for basic authentication and only * When applied, this middleware will ensure the request contains the required header for basic authentication and only
* allow access to the endpoint after successful authentication. * allow access to the endpoint after successful authentication.
*/ */
const { getConfig } = require('./../util.js');
//const {dirname} = require('path');
//const appDir = dirname(require.main.filename);
//const config = require(appDir + '/config.conf');
const path = require('path');
const config = require(path.join(process.cwd(), './config.conf'));
const unauthorizedResponse = (res) => { const unauthorizedResponse = (res) => {
res.set('WWW-Authenticate', 'Basic realm="SillyTavern", charset="UTF-8"'); res.set('WWW-Authenticate', 'Basic realm="SillyTavern", charset="UTF-8"');
@ -15,6 +10,7 @@ const unauthorizedResponse = (res) => {
}; };
const basicAuthMiddleware = function (request, response, callback) { const basicAuthMiddleware = function (request, response, callback) {
const config = getConfig();
const authHeader = request.headers.authorization; const authHeader = request.headers.authorization;
if (!authHeader) { if (!authHeader) {

48
src/openai-vectors.js Normal file
View File

@ -0,0 +1,48 @@
const fetch = require('node-fetch').default;
const { SECRET_KEYS, readSecret } = require('./secrets');
/**
* Gets the vector for the given text from OpenAI ada model
* @param {string} text - The text to get the vector for
* @returns {Promise<number[]>} - The vector for the text
*/
async function getOpenAIVector(text) {
const key = readSecret(SECRET_KEYS.OPENAI);
if (!key) {
console.log('No OpenAI key found');
throw new Error('No OpenAI key found');
}
const response = await fetch('https://api.openai.com/v1/embeddings', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${key}`,
},
body: JSON.stringify({
input: text,
model: 'text-embedding-ada-002',
})
});
if (!response.ok) {
const text = await response.text();
console.log('OpenAI request failed', response.statusText, text);
throw new Error('OpenAI request failed');
}
const data = await response.json();
const vector = data?.data[0]?.embedding;
if (!Array.isArray(vector)) {
console.log('OpenAI response was not an array');
throw new Error('OpenAI response was not an array');
}
return vector;
}
module.exports = {
getOpenAIVector,
};

148
src/secrets.js Normal file
View File

@ -0,0 +1,148 @@
const fs = require('fs');
const path = require('path');
const writeFileAtomicSync = require('write-file-atomic').sync;
const SECRETS_FILE = path.join(process.cwd(), './secrets.json');
const SECRET_KEYS = {
HORDE: 'api_key_horde',
MANCER: 'api_key_mancer',
OPENAI: 'api_key_openai',
NOVEL: 'api_key_novel',
CLAUDE: 'api_key_claude',
DEEPL: 'deepl',
LIBRE: 'libre',
LIBRE_URL: 'libre_url',
OPENROUTER: 'api_key_openrouter',
SCALE: 'api_key_scale',
AI21: 'api_key_ai21',
SCALE_COOKIE: 'scale_cookie',
ONERING_URL: 'oneringtranslator_url',
DEEPLX_URL: 'deeplx_url',
}
/**
* Writes a secret to the secrets file
* @param {string} key Secret key
* @param {string} value Secret value
*/
function writeSecret(key, value) {
if (!fs.existsSync(SECRETS_FILE)) {
const emptyFile = JSON.stringify({});
writeFileAtomicSync(SECRETS_FILE, emptyFile, "utf-8");
}
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
const secrets = JSON.parse(fileContents);
secrets[key] = value;
writeFileAtomicSync(SECRETS_FILE, JSON.stringify(secrets), "utf-8");
}
/**
* Reads a secret from the secrets file
* @param {string} key Secret key
* @returns {string} Secret value
*/
function readSecret(key) {
if (!fs.existsSync(SECRETS_FILE)) {
return '';
}
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
const secrets = JSON.parse(fileContents);
return secrets[key];
}
/**
* Reads the secret state from the secrets file
* @returns {object} Secret state
*/
function readSecretState() {
if (!fs.existsSync(SECRETS_FILE)) {
return {};
}
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf8');
const secrets = JSON.parse(fileContents);
const state = {};
for (const key of Object.values(SECRET_KEYS)) {
state[key] = !!secrets[key]; // convert to boolean
}
return state;
}
/**
* Migrates secrets from settings.json to secrets.json
* @param {string} settingsFile Path to settings.json
* @returns {void}
*/
function migrateSecrets(settingsFile) {
if (!fs.existsSync(settingsFile)) {
console.log('Settings file does not exist');
return;
}
try {
let modified = false;
const fileContents = fs.readFileSync(settingsFile, 'utf8');
const settings = JSON.parse(fileContents);
const oaiKey = settings?.api_key_openai;
const hordeKey = settings?.horde_settings?.api_key;
const novelKey = settings?.api_key_novel;
if (typeof oaiKey === 'string') {
console.log('Migrating OpenAI key...');
writeSecret(SECRET_KEYS.OPENAI, oaiKey);
delete settings.api_key_openai;
modified = true;
}
if (typeof hordeKey === 'string') {
console.log('Migrating Horde key...');
writeSecret(SECRET_KEYS.HORDE, hordeKey);
delete settings.horde_settings.api_key;
modified = true;
}
if (typeof novelKey === 'string') {
console.log('Migrating Novel key...');
writeSecret(SECRET_KEYS.NOVEL, novelKey);
delete settings.api_key_novel;
modified = true;
}
if (modified) {
console.log('Writing updated settings.json...');
const settingsContent = JSON.stringify(settings);
writeFileAtomicSync(settingsFile, settingsContent, "utf-8");
}
}
catch (error) {
console.error('Could not migrate secrets file. Proceed with caution.');
}
}
/**
* Reads all secrets from the secrets file
* @returns {Record<string, string> | undefined} Secrets
*/
function getAllSecrets() {
if (!fs.existsSync(SECRETS_FILE)) {
console.log('Secrets file does not exist');
return undefined;
}
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf8');
const secrets = JSON.parse(fileContents);
return secrets;
}
module.exports = {
writeSecret,
readSecret,
readSecretState,
migrateSecrets,
getAllSecrets,
SECRET_KEYS,
};

248
src/translate.js Normal file
View File

@ -0,0 +1,248 @@
const fetch = require('node-fetch').default;
const https = require('https');
const { readSecret, SECRET_KEYS } = require('./secrets');
const { generateRequestUrl, normaliseResponse } = require('google-translate-api-browser');
const DEEPLX_URL_DEFAULT = 'http://127.0.0.1:1188/translate';
const ONERING_URL_DEFAULT = 'http://127.0.0.1:4990/translate';
/**
* @param {import("express").Express} app
* @param {any} jsonParser
*/
function registerEndpoints(app, jsonParser) {
app.post('/api/translate/libre', jsonParser, async (request, response) => {
const key = readSecret(SECRET_KEYS.LIBRE);
const url = readSecret(SECRET_KEYS.LIBRE_URL);
if (!url) {
console.log('LibreTranslate URL is not configured.');
return response.sendStatus(401);
}
const text = request.body.text;
const lang = request.body.lang;
if (!text || !lang) {
return response.sendStatus(400);
}
console.log('Input text: ' + text);
try {
const result = await fetch(url, {
method: "POST",
body: JSON.stringify({
q: text,
source: "auto",
target: lang,
format: "text",
api_key: key
}),
headers: { "Content-Type": "application/json" }
});
if (!result.ok) {
const error = await result.text();
console.log('LibreTranslate error: ', result.statusText, error);
return response.sendStatus(result.status);
}
const json = await result.json();
console.log('Translated text: ' + json.translatedText);
return response.send(json.translatedText);
} catch (error) {
console.log("Translation error: " + error.message);
return response.sendStatus(500);
}
});
app.post('/api/translate/google', jsonParser, async (request, response) => {
const text = request.body.text;
const lang = request.body.lang;
if (!text || !lang) {
return response.sendStatus(400);
}
console.log('Input text: ' + text);
const url = generateRequestUrl(text, { to: lang });
https.get(url, (resp) => {
let data = '';
resp.on('data', (chunk) => {
data += chunk;
});
resp.on('end', () => {
const result = normaliseResponse(JSON.parse(data));
console.log('Translated text: ' + result.text);
return response.send(result.text);
});
}).on("error", (err) => {
console.log("Translation error: " + err.message);
return response.sendStatus(500);
});
});
app.post('/api/translate/deepl', jsonParser, async (request, response) => {
const key = readSecret(SECRET_KEYS.DEEPL);
if (!key) {
return response.sendStatus(401);
}
const text = request.body.text;
const lang = request.body.lang;
if (!text || !lang) {
return response.sendStatus(400);
}
console.log('Input text: ' + text);
const params = new URLSearchParams();
params.append('text', text);
params.append('target_lang', lang);
try {
const result = await fetch('https://api-free.deepl.com/v2/translate', {
method: 'POST',
body: params,
headers: {
'Accept': 'application/json',
'Authorization': `DeepL-Auth-Key ${key}`,
'Content-Type': 'application/x-www-form-urlencoded',
},
timeout: 0,
});
if (!result.ok) {
const error = await result.text();
console.log('DeepL error: ', result.statusText, error);
return response.sendStatus(result.status);
}
const json = await result.json();
console.log('Translated text: ' + json.translations[0].text);
return response.send(json.translations[0].text);
} catch (error) {
console.log("Translation error: " + error.message);
return response.sendStatus(500);
}
});
app.post('/api/translate/onering', jsonParser, async (request, response) => {
const secretUrl = readSecret(SECRET_KEYS.ONERING_URL);
const url = secretUrl || ONERING_URL_DEFAULT;
if (!url) {
console.log('OneRing URL is not configured.');
return response.sendStatus(401);
}
if (!secretUrl && url === ONERING_URL_DEFAULT) {
console.log('OneRing URL is using default value.', ONERING_URL_DEFAULT);
}
const text = request.body.text;
const from_lang = request.body.from_lang;
const to_lang = request.body.to_lang;
if (!text || !from_lang || !to_lang) {
return response.sendStatus(400);
}
const params = new URLSearchParams();
params.append('text', text);
params.append('from_lang', from_lang);
params.append('to_lang', to_lang);
console.log('Input text: ' + text);
try {
const fetchUrl = new URL(url);
fetchUrl.search = params.toString();
const result = await fetch(fetchUrl, {
method: 'GET',
timeout: 0,
});
if (!result.ok) {
const error = await result.text();
console.log('OneRing error: ', result.statusText, error);
return response.sendStatus(result.status);
}
const data = await result.json();
console.log('Translated text: ' + data.result);
return response.send(data.result);
} catch (error) {
console.log("Translation error: " + error.message);
return response.sendStatus(500);
}
});
app.post('/api/translate/deeplx', jsonParser, async (request, response) => {
const secretUrl = readSecret(SECRET_KEYS.DEEPLX_URL);
const url = secretUrl || DEEPLX_URL_DEFAULT;
if (!url) {
console.log('DeepLX URL is not configured.');
return response.sendStatus(401);
}
if (!secretUrl && url === DEEPLX_URL_DEFAULT) {
console.log('DeepLX URL is using default value.', DEEPLX_URL_DEFAULT);
}
const text = request.body.text;
const lang = request.body.lang;
if (!text || !lang) {
return response.sendStatus(400);
}
console.log('Input text: ' + text);
try {
const result = await fetch(url, {
method: 'POST',
body: JSON.stringify({
text: text,
source_lang: 'auto',
target_lang: lang,
}),
headers: {
'Accept': 'application/json',
'Content-Type': 'application/json',
},
timeout: 0,
});
if (!result.ok) {
const error = await result.text();
console.log('DeepLX error: ', result.statusText, error);
return response.sendStatus(result.status);
}
const json = await result.json();
console.log('Translated text: ' + json.data);
return response.send(json.data);
} catch (error) {
console.log("DeepLX translation error: " + error.message);
return response.sendStatus(500);
}
});
}
module.exports = {
registerEndpoints,
};

10
src/util.js Normal file
View File

@ -0,0 +1,10 @@
const path = require('path');
function getConfig() {
const config = require(path.join(process.cwd(), './config.conf'));
return config;
}
module.exports = {
getConfig,
};

221
src/vectors.js Normal file
View File

@ -0,0 +1,221 @@
const express = require('express');
const vectra = require('vectra');
const path = require('path');
const sanitize = require('sanitize-filename');
/**
* Gets the vector for the given text from the given source.
* @param {string} source - The source of the vector
* @param {string} text - The text to get the vector for
* @returns {Promise<number[]>} - The vector for the text
*/
async function getVector(source, text) {
switch (source) {
case 'local':
return require('./local-vectors').getLocalVector(text);
case 'openai':
return require('./openai-vectors').getOpenAIVector(text);
}
throw new Error(`Unknown vector source ${source}`);
}
/**
* Gets the index for the vector collection
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {boolean} create - Whether to create the index if it doesn't exist
* @returns {Promise<vectra.LocalIndex>} - The index for the collection
*/
async function getIndex(collectionId, source, create = true) {
const index = new vectra.LocalIndex(path.join(process.cwd(), 'vectors', sanitize(source), sanitize(collectionId)));
if (create && !await index.isIndexCreated()) {
await index.createIndex();
}
return index;
}
/**
* Inserts items into the vector collection
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {{ hash: number; text: string; }[]} items - The items to insert
*/
async function insertVectorItems(collectionId, source, items) {
const index = await getIndex(collectionId, source);
await index.beginUpdate();
for (const item of items) {
const text = item.text;
const hash = item.hash;
const vector = await getVector(source, text);
await index.upsertItem({ vector: vector, metadata: { hash, text } });
}
await index.endUpdate();
}
/**
* Gets the hashes of the items in the vector collection
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @returns {Promise<number[]>} - The hashes of the items in the collection
*/
async function getSavedHashes(collectionId, source) {
const index = await getIndex(collectionId, source);
const items = await index.listItems();
const hashes = items.map(x => Number(x.metadata.hash));
return hashes;
}
/**
* Deletes items from the vector collection by hash
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {number[]} hashes - The hashes of the items to delete
*/
async function deleteVectorItems(collectionId, source, hashes) {
const index = await getIndex(collectionId, source);
const items = await index.listItemsByMetadata({ hash: { '$in': hashes } });
await index.beginUpdate();
for (const item of items) {
await index.deleteItem(item.id);
}
await index.endUpdate();
}
/**
* Gets the hashes of the items in the vector collection that match the search text
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {string} searchText - The text to search for
* @param {number} topK - The number of results to return
* @returns {Promise<number[]>} - The hashes of the items that match the search text
*/
async function queryCollection(collectionId, source, searchText, topK) {
const index = await getIndex(collectionId, source);
const vector = await getVector(source, searchText);
const result = await index.queryItems(vector, topK);
const hashes = result.map(x => Number(x.item.metadata.hash));
return hashes;
}
/**
* Registers the endpoints for the vector API
* @param {express.Express} app - Express app
* @param {any} jsonParser - Express JSON parser
*/
async function registerEndpoints(app, jsonParser) {
app.post('/api/vector/query', jsonParser, async (req, res) => {
try {
if (!req.body.collectionId || !req.body.searchText) {
return res.sendStatus(400);
}
const collectionId = String(req.body.collectionId);
const searchText = String(req.body.searchText);
const topK = Number(req.body.topK) || 10;
const source = String(req.body.source) || 'local';
const results = await queryCollection(collectionId, source, searchText, topK);
return res.json(results);
} catch (error) {
console.error(error);
return res.sendStatus(500);
}
});
app.post('/api/vector/insert', jsonParser, async (req, res) => {
try {
if (!Array.isArray(req.body.items) || !req.body.collectionId) {
return res.sendStatus(400);
}
const collectionId = String(req.body.collectionId);
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text }));
const source = String(req.body.source) || 'local';
await insertVectorItems(collectionId, source, items);
return res.sendStatus(200);
} catch (error) {
console.error(error);
return res.sendStatus(500);
}
});
app.post('/api/vector/list', jsonParser, async (req, res) => {
try {
if (!req.body.collectionId) {
return res.sendStatus(400);
}
const collectionId = String(req.body.collectionId);
const source = String(req.body.source) || 'local';
const hashes = await getSavedHashes(collectionId, source);
return res.json(hashes);
} catch (error) {
console.error(error);
return res.sendStatus(500);
}
});
app.post('/api/vector/delete', jsonParser, async (req, res) => {
try {
if (!Array.isArray(req.body.hashes) || !req.body.collectionId) {
return res.sendStatus(400);
}
const collectionId = String(req.body.collectionId);
const hashes = req.body.hashes.map(x => Number(x));
const source = String(req.body.source) || 'local';
await deleteVectorItems(collectionId, source, hashes);
return res.sendStatus(200);
} catch (error) {
console.error(error);
return res.sendStatus(500);
}
});
app.post('/api/vector/purge', jsonParser, async (req, res) => {
try {
if (!req.body.collectionId) {
return res.sendStatus(400);
}
const collectionId = String(req.body.collectionId);
const sources = ['local', 'openai'];
for (const source of sources) {
const index = await getIndex(collectionId, source, false);
const exists = await index.isIndexCreated();
if (!exists) {
continue;
}
const path = index.folderPath;
await index.deleteIndex();
console.log(`Deleted vector index at ${path}`);
}
return res.sendStatus(200);
} catch (error) {
console.error(error);
return res.sendStatus(500);
}
});
}
module.exports = { registerEndpoints };