diff --git a/default/config.yaml b/default/config.yaml index 48fa68cec..04781eb19 100644 --- a/default/config.yaml +++ b/default/config.yaml @@ -25,6 +25,9 @@ autorun: true disableThumbnails: false # Thumbnail quality (0-100) thumbnailsQuality: 95 +# Generate avatar thumbnails as PNG instead of JPG (preserves transparency but increases filesize by about 100%) +# Changing this only affects new thumbnails. To recreate the old ones, clear out your ST/thumbnails/ folder. +avatarThumbnailsPng: false # Allow secret keys exposure via API allowKeysExposure: false # Skip new default content checks diff --git a/package-lock.json b/package-lock.json index 1ee5316bb..282dc6ca3 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "sillytavern", - "version": "1.11.1", + "version": "1.11.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "sillytavern", - "version": "1.11.1", + "version": "1.11.2", "hasInstallScript": true, "license": "AGPL-3.0", "dependencies": { diff --git a/package.json b/package.json index 40eddd76b..1c6f278de 100644 --- a/package.json +++ b/package.json @@ -51,7 +51,7 @@ "type": "git", "url": "https://github.com/SillyTavern/SillyTavern.git" }, - "version": "1.11.1", + "version": "1.11.2", "scripts": { "start": "node server.js", "start-multi": "node server.js --disableCsrf", diff --git a/public/index.html b/public/index.html index 8a8982ebd..28d41d52d 100644 --- a/public/index.html +++ b/public/index.html @@ -1495,7 +1495,18 @@ Add character names
- Send names in the ChatML objects. Helps the model to associate messages with characters. + Send names in the message objects. Helps the model to associate messages with characters. +
+ +
+ +
+ + Continue sends the last message as assistant role instead of system message with instruction. +
diff --git a/public/script.js b/public/script.js index d38432a54..5cae61df4 100644 --- a/public/script.js +++ b/public/script.js @@ -9581,6 +9581,10 @@ jQuery(async function () { valueBeforeManualInput = $(this).val(); console.log(valueBeforeManualInput); }) + .on('change', function (e) { + e.target.focus(); + e.target.dispatchEvent(new Event('keyup')); + }) .on('keydown', function (e) { const masterSelector = '#' + $(this).data('for'); const masterElement = $(masterSelector); diff --git a/public/scripts/extensions.js b/public/scripts/extensions.js index 120791842..5d69673fa 100644 --- a/public/scripts/extensions.js +++ b/public/scripts/extensions.js @@ -110,6 +110,7 @@ const extension_settings = { sd: { prompts: {}, character_prompts: {}, + character_negative_prompts: {}, }, chromadb: {}, translate: {}, diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index 3e8072711..e8bc8f1e2 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -351,6 +351,10 @@ async function loadSettings() { extension_settings.sd.character_prompts = {}; } + if (extension_settings.sd.character_negative_prompts === undefined) { + extension_settings.sd.character_negative_prompts = {}; + } + if (!Array.isArray(extension_settings.sd.styles)) { extension_settings.sd.styles = defaultStyles; } @@ -575,6 +579,7 @@ function onChatChanged() { $('#sd_character_prompt_block').show(); const key = getCharaFilename(this_chid); $('#sd_character_prompt').val(key ? (extension_settings.sd.character_prompts[key] || '') : ''); + $('#sd_character_negative_prompt').val(key ? (extension_settings.sd.character_negative_prompts[key] || '') : ''); } function onCharacterPromptInput() { @@ -584,6 +589,13 @@ function onCharacterPromptInput() { saveSettingsDebounced(); } +function onCharacterNegativePromptInput() { + const key = getCharaFilename(this_chid); + extension_settings.sd.character_negative_prompts[key] = $('#sd_character_negative_prompt').val(); + resetScrollHeight($(this)); + saveSettingsDebounced(); +} + function getCharacterPrefix() { if (!this_chid || selected_group) { return ''; @@ -598,6 +610,20 @@ function getCharacterPrefix() { return ''; } +function getCharacterNegativePrefix() { + if (!this_chid || selected_group) { + return ''; + } + + const key = getCharaFilename(this_chid); + + if (key) { + return extension_settings.sd.character_negative_prompts[key] || ''; + } + + return ''; +} + /** * Combines two prompt prefixes into one. * @param {string} str1 Base string @@ -1885,34 +1911,38 @@ async function sendGenerationRequest(generationType, prompt, characterName = nul const prefixedPrompt = combinePrefixes(prefix, prompt, '{prompt}'); + const negativePrompt = noCharPrefix.includes(generationType) + ? extension_settings.sd.negative_prompt + : combinePrefixes(extension_settings.sd.negative_prompt, getCharacterNegativePrefix()); + let result = { format: '', data: '' }; const currentChatId = getCurrentChatId(); try { switch (extension_settings.sd.source) { case sources.extras: - result = await generateExtrasImage(prefixedPrompt); + result = await generateExtrasImage(prefixedPrompt, negativePrompt); break; case sources.horde: - result = await generateHordeImage(prefixedPrompt); + result = await generateHordeImage(prefixedPrompt, negativePrompt); break; case sources.vlad: - result = await generateAutoImage(prefixedPrompt); + result = await generateAutoImage(prefixedPrompt, negativePrompt); break; case sources.auto: - result = await generateAutoImage(prefixedPrompt); + result = await generateAutoImage(prefixedPrompt, negativePrompt); break; case sources.novel: - result = await generateNovelImage(prefixedPrompt); + result = await generateNovelImage(prefixedPrompt, negativePrompt); break; case sources.openai: result = await generateOpenAiImage(prefixedPrompt); break; case sources.comfy: - result = await generateComfyImage(prefixedPrompt); + result = await generateComfyImage(prefixedPrompt, negativePrompt); break; case sources.togetherai: - result = await generateTogetherAIImage(prefixedPrompt); + result = await generateTogetherAIImage(prefixedPrompt, negativePrompt); break; } @@ -1936,13 +1966,13 @@ async function sendGenerationRequest(generationType, prompt, characterName = nul callback ? callback(prompt, base64Image, generationType) : sendMessage(prompt, base64Image, generationType); } -async function generateTogetherAIImage(prompt) { +async function generateTogetherAIImage(prompt, negativePrompt) { const result = await fetch('/api/sd/together/generate', { method: 'POST', headers: getRequestHeaders(), body: JSON.stringify({ prompt: prompt, - negative_prompt: extension_settings.sd.negative_prompt, + negative_prompt: negativePrompt, model: extension_settings.sd.model, steps: extension_settings.sd.steps, width: extension_settings.sd.width, @@ -1963,9 +1993,10 @@ async function generateTogetherAIImage(prompt) { * Generates an "extras" image using a provided prompt and other settings. * * @param {string} prompt - The main instruction used to guide the image generation. + * @param {string} negativePrompt - The instruction used to restrict the image generation. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. */ -async function generateExtrasImage(prompt) { +async function generateExtrasImage(prompt, negativePrompt) { const url = new URL(getApiUrl()); url.pathname = '/api/image'; const result = await doExtrasFetch(url, { @@ -1980,7 +2011,7 @@ async function generateExtrasImage(prompt) { scale: extension_settings.sd.scale, width: extension_settings.sd.width, height: extension_settings.sd.height, - negative_prompt: extension_settings.sd.negative_prompt, + negative_prompt: negativePrompt, restore_faces: !!extension_settings.sd.restore_faces, enable_hr: !!extension_settings.sd.enable_hr, karras: !!extension_settings.sd.horde_karras, @@ -2004,9 +2035,10 @@ async function generateExtrasImage(prompt) { * Generates a "horde" image using the provided prompt and configuration settings. * * @param {string} prompt - The main instruction used to guide the image generation. + * @param {string} negativePrompt - The instruction used to restrict the image generation. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. */ -async function generateHordeImage(prompt) { +async function generateHordeImage(prompt, negativePrompt) { const result = await fetch('/api/horde/generate-image', { method: 'POST', headers: getRequestHeaders(), @@ -2017,7 +2049,7 @@ async function generateHordeImage(prompt) { scale: extension_settings.sd.scale, width: extension_settings.sd.width, height: extension_settings.sd.height, - negative_prompt: extension_settings.sd.negative_prompt, + negative_prompt: negativePrompt, model: extension_settings.sd.model, nsfw: extension_settings.sd.horde_nsfw, restore_faces: !!extension_settings.sd.restore_faces, @@ -2039,16 +2071,17 @@ async function generateHordeImage(prompt) { * Generates an image in SD WebUI API using the provided prompt and configuration settings. * * @param {string} prompt - The main instruction used to guide the image generation. + * @param {string} negativePrompt - The instruction used to restrict the image generation. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. */ -async function generateAutoImage(prompt) { +async function generateAutoImage(prompt, negativePrompt) { const result = await fetch('/api/sd/generate', { method: 'POST', headers: getRequestHeaders(), body: JSON.stringify({ ...getSdRequestBody(), prompt: prompt, - negative_prompt: extension_settings.sd.negative_prompt, + negative_prompt: negativePrompt, sampler_name: extension_settings.sd.sampler, steps: extension_settings.sd.steps, cfg_scale: extension_settings.sd.scale, @@ -2081,9 +2114,10 @@ async function generateAutoImage(prompt) { * Generates an image in NovelAI API using the provided prompt and configuration settings. * * @param {string} prompt - The main instruction used to guide the image generation. + * @param {string} negativePrompt - The instruction used to restrict the image generation. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. */ -async function generateNovelImage(prompt) { +async function generateNovelImage(prompt, negativePrompt) { const { steps, width, height } = getNovelParams(); const result = await fetch('/api/novelai/generate-image', { @@ -2097,7 +2131,7 @@ async function generateNovelImage(prompt) { scale: extension_settings.sd.scale, width: width, height: height, - negative_prompt: extension_settings.sd.negative_prompt, + negative_prompt: negativePrompt, upscale_ratio: extension_settings.sd.novel_upscale_ratio, }), }); @@ -2225,11 +2259,11 @@ async function generateOpenAiImage(prompt) { * Generates an image in ComfyUI using the provided prompt and configuration settings. * * @param {string} prompt - The main instruction used to guide the image generation. + * @param {string} negativePrompt - The instruction used to restrict the image generation. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. */ -async function generateComfyImage(prompt) { +async function generateComfyImage(prompt, negativePrompt) { const placeholders = [ - 'negative_prompt', 'model', 'vae', 'sampler', @@ -2252,6 +2286,7 @@ async function generateComfyImage(prompt) { toastr.error(`Failed to load workflow.\n\n${text}`); } let workflow = (await workflowResponse.json()).replace('"%prompt%"', JSON.stringify(prompt)); + workflow = (await workflowResponse.json()).replace('"%negative_prompt%"', JSON.stringify(negativePrompt)); workflow = workflow.replace('"%seed%"', JSON.stringify(Math.round(Math.random() * Number.MAX_SAFE_INTEGER))); placeholders.forEach(ph => { workflow = workflow.replace(`"%${ph}%"`, JSON.stringify(extension_settings.sd[ph])); @@ -2629,6 +2664,7 @@ jQuery(async () => { $('#sd_enable_hr').on('input', onHighResFixInput); $('#sd_refine_mode').on('input', onRefineModeInput); $('#sd_character_prompt').on('input', onCharacterPromptInput); + $('#sd_character_negative_prompt').on('input', onCharacterNegativePromptInput); $('#sd_auto_validate').on('click', validateAutoUrl); $('#sd_auto_url').on('input', onAutoUrlInput); $('#sd_auto_auth').on('input', onAutoAuthInput); @@ -2661,6 +2697,7 @@ jQuery(async () => { initScrollHeight($('#sd_prompt_prefix')); initScrollHeight($('#sd_negative_prompt')); initScrollHeight($('#sd_character_prompt')); + initScrollHeight($('#sd_character_negative_prompt')); }); for (const [key, value] of Object.entries(resolutionOptions)) { diff --git a/public/scripts/extensions/stable-diffusion/settings.html b/public/scripts/extensions/stable-diffusion/settings.html index e42a24348..1ec94f2e2 100644 --- a/public/scripts/extensions/stable-diffusion/settings.html +++ b/public/scripts/extensions/stable-diffusion/settings.html @@ -208,6 +208,9 @@ Won't be used in groups. + + Won't be used in groups. +
diff --git a/public/scripts/horde.js b/public/scripts/horde.js index bd4322a69..2581d522a 100644 --- a/public/scripts/horde.js +++ b/public/scripts/horde.js @@ -2,7 +2,6 @@ import { saveSettingsDebounced, callPopup, setGenerationProgress, - CLIENT_VERSION, getRequestHeaders, max_context, amount_gen, @@ -34,19 +33,96 @@ let horde_settings = { const MAX_RETRIES = 480; const CHECK_INTERVAL = 2500; const MIN_LENGTH = 16; -const getRequestArgs = () => ({ - method: 'GET', - headers: { - 'Client-Agent': CLIENT_VERSION, - }, -}); -async function getWorkers(workerType) { - const response = await fetch('https://horde.koboldai.net/api/v2/workers?type=text', getRequestArgs()); +/** + * Gets the available workers from Horde. + * @param {boolean} force Do a force refresh of the workers + * @returns {Promise} Array of workers + */ +async function getWorkers(force) { + const response = await fetch('/api/horde/text-workers', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ force }), + }); const data = await response.json(); return data; } +/** + * Gets the available models from Horde. + * @param {boolean} force Do a force refresh of the models + * @returns {Promise} Array of models + */ +async function getModels(force) { + const response = await fetch('/api/horde/text-models', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ force }), + }); + const data = await response.json(); + return data; +} + +/** + * Gets the status of a Horde task. + * @param {string} taskId Task ID + * @returns {Promise} Task status + */ +async function getTaskStatus(taskId) { + const response = await fetch('/api/horde/task-status', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ taskId }), + }); + + if (!response.ok) { + throw new Error(`Failed to get task status: ${response.statusText}`); + } + + const data = await response.json(); + return data; +} + +/** + * Cancels a Horde task. + * @param {string} taskId Task ID + */ +async function cancelTask(taskId) { + const response = await fetch('/api/horde/cancel-task', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ taskId }), + }); + + if (!response.ok) { + throw new Error(`Failed to cancel task: ${response.statusText}`); + } +} + +/** + * Checks if Horde is online. + * @returns {Promise} True if Horde is online, false otherwise + */ +async function checkHordeStatus() { + try { + const response = await fetch('/api/horde/status', { + method: 'POST', + headers: getRequestHeaders(), + }); + + if (!response.ok) { + return false; + } + + const data = await response.json(); + return data.ok; + } catch (error) { + console.error(error); + return false; + } +} + function validateHordeModel() { let selectedModels = models.filter(m => horde_settings.models.includes(m.name)); @@ -60,7 +136,7 @@ function validateHordeModel() { async function adjustHordeGenerationParams(max_context_length, max_length) { console.log(max_context_length, max_length); - const workers = await getWorkers(); + const workers = await getWorkers(false); let maxContextLength = max_context_length; let maxLength = max_length; let availableWorkers = []; @@ -126,10 +202,7 @@ async function generateHorde(prompt, params, signal, reportProgress) { const response = await fetch('/api/horde/generate-text', { method: 'POST', - headers: { - ...getRequestHeaders(), - 'Client-Agent': CLIENT_VERSION, - }, + headers: getRequestHeaders(), body: JSON.stringify(payload), }); @@ -146,24 +219,17 @@ async function generateHorde(prompt, params, signal, reportProgress) { throw new Error(`Horde generation failed: ${reason}`); } - const task_id = responseJson.id; + const taskId = responseJson.id; let queue_position_first = null; - console.log(`Horde task id = ${task_id}`); + console.log(`Horde task id = ${taskId}`); for (let retryNumber = 0; retryNumber < MAX_RETRIES; retryNumber++) { if (signal.aborted) { - fetch(`https://horde.koboldai.net/api/v2/generate/text/status/${task_id}`, { - method: 'DELETE', - headers: { - 'Client-Agent': CLIENT_VERSION, - }, - }); + cancelTask(taskId); throw new Error('Request aborted'); } - const statusCheckResponse = await fetch(`https://horde.koboldai.net/api/v2/generate/text/status/${task_id}`, getRequestArgs()); - - const statusCheckJson = await statusCheckResponse.json(); + const statusCheckJson = await getTaskStatus(taskId); console.log(statusCheckJson); if (statusCheckJson.faulted === true) { @@ -202,18 +268,13 @@ async function generateHorde(prompt, params, signal, reportProgress) { throw new Error('Horde timeout'); } -async function checkHordeStatus() { - const response = await fetch('https://horde.koboldai.net/api/v2/status/heartbeat', getRequestArgs()); - return response.ok; -} - -async function getHordeModels() { +/** + * Displays the available models in the Horde model selection dropdown. + * @param {boolean} force Force refresh of the models + */ +async function getHordeModels(force) { $('#horde_model').empty(); - const response = await fetch('https://horde.koboldai.net/api/v2/status/models?type=text', getRequestArgs()); - models = await response.json(); - models.sort((a, b) => { - return b.performance - a.performance; - }); + models = (await getModels(force)).sort((a, b) => b.performance - a.performance); for (const model of models) { const option = document.createElement('option'); option.value = model.name; @@ -299,7 +360,7 @@ jQuery(function () { await writeSecret(SECRET_KEYS.HORDE, key); }); - $('#horde_refresh').on('click', getHordeModels); + $('#horde_refresh').on('click', () => getHordeModels(true)); $('#horde_kudos').on('click', showKudos); // Not needed on mobile diff --git a/public/scripts/openai.js b/public/scripts/openai.js index c8821eda8..35f232dcb 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -239,6 +239,7 @@ const default_settings = { squash_system_messages: false, image_inlining: false, bypass_status_check: false, + continue_prefill: false, seed: -1, }; @@ -302,6 +303,7 @@ const oai_settings = { squash_system_messages: false, image_inlining: false, bypass_status_check: false, + continue_prefill: false, seed: -1, }; @@ -660,12 +662,20 @@ async function populateChatHistory(messages, prompts, chatCompletion, type = nul let continueMessage = null; const instruct = isOpenRouterWithInstruct(); if (type === 'continue' && cyclePrompt && !instruct) { - const continuePrompt = new Prompt({ - identifier: 'continueNudge', - role: 'system', - content: oai_settings.continue_nudge_prompt.replace('{{lastChatMessage}}', cyclePrompt), - system_prompt: true, - }); + const promptObject = oai_settings.continue_prefill ? + { + identifier: 'continueNudge', + role: 'assistant', + content: cyclePrompt, + system_prompt: true, + } : + { + identifier: 'continueNudge', + role: 'system', + content: oai_settings.continue_nudge_prompt.replace('{{lastChatMessage}}', cyclePrompt), + system_prompt: true, + }; + const continuePrompt = new Prompt(promptObject); const preparedPrompt = promptManager.preparePrompt(continuePrompt); continueMessage = Message.fromPrompt(preparedPrompt); chatCompletion.reserveBudget(continueMessage); @@ -2376,6 +2386,7 @@ function loadOpenAISettings(data, settings) { oai_settings.new_example_chat_prompt = settings.new_example_chat_prompt ?? default_settings.new_example_chat_prompt; oai_settings.continue_nudge_prompt = settings.continue_nudge_prompt ?? default_settings.continue_nudge_prompt; oai_settings.squash_system_messages = settings.squash_system_messages ?? default_settings.squash_system_messages; + oai_settings.continue_prefill = settings.continue_prefill ?? default_settings.continue_prefill; if (settings.wrap_in_quotes !== undefined) oai_settings.wrap_in_quotes = !!settings.wrap_in_quotes; if (settings.names_in_completion !== undefined) oai_settings.names_in_completion = !!settings.names_in_completion; @@ -2428,6 +2439,7 @@ function loadOpenAISettings(data, settings) { $('#openrouter_force_instruct').prop('checked', oai_settings.openrouter_force_instruct); $('#openrouter_group_models').prop('checked', oai_settings.openrouter_group_models); $('#squash_system_messages').prop('checked', oai_settings.squash_system_messages); + $('#continue_prefill').prop('checked', oai_settings.continue_prefill); if (settings.impersonation_prompt !== undefined) oai_settings.impersonation_prompt = settings.impersonation_prompt; $('#impersonation_prompt_textarea').val(oai_settings.impersonation_prompt); @@ -2593,6 +2605,10 @@ async function saveOpenAIPreset(name, settings, triggerUi = true) { ai21_model: settings.ai21_model, mistralai_model: settings.mistralai_model, custom_model: settings.custom_model, + custom_url: settings.custom_url, + custom_include_body: settings.custom_include_body, + custom_exclude_body: settings.custom_exclude_body, + custom_include_headers: settings.custom_include_headers, google_model: settings.google_model, temperature: settings.temp_openai, frequency_penalty: settings.freq_pen_openai, @@ -2634,6 +2650,8 @@ async function saveOpenAIPreset(name, settings, triggerUi = true) { use_alt_scale: settings.use_alt_scale, squash_system_messages: settings.squash_system_messages, image_inlining: settings.image_inlining, + bypass_status_check: settings.bypass_status_check, + continue_prefill: settings.continue_prefill, seed: settings.seed, }; @@ -3004,6 +3022,7 @@ function onSettingsPresetChange() { use_alt_scale: ['#use_alt_scale', 'use_alt_scale', true], squash_system_messages: ['#squash_system_messages', 'squash_system_messages', true], image_inlining: ['#openai_image_inlining', 'image_inlining', true], + continue_prefill: ['#continue_prefill', 'continue_prefill', true], seed: ['#seed_openai', 'seed', false], }; @@ -3584,17 +3603,17 @@ function onCustomizeParametersClick() { `); - template.find('#custom_include_body').val(oai_settings.custom_include_body).on('input', function() { + template.find('#custom_include_body').val(oai_settings.custom_include_body).on('input', function () { oai_settings.custom_include_body = String($(this).val()); saveSettingsDebounced(); }); - template.find('#custom_exclude_body').val(oai_settings.custom_exclude_body).on('input', function() { + template.find('#custom_exclude_body').val(oai_settings.custom_exclude_body).on('input', function () { oai_settings.custom_exclude_body = String($(this).val()); saveSettingsDebounced(); }); - template.find('#custom_include_headers').val(oai_settings.custom_include_headers).on('input', function() { + template.find('#custom_include_headers').val(oai_settings.custom_include_headers).on('input', function () { oai_settings.custom_include_headers = String($(this).val()); saveSettingsDebounced(); }); @@ -3928,6 +3947,11 @@ $(document).ready(async function () { saveSettingsDebounced(); }); + $('#continue_prefill').on('input', function () { + oai_settings.continue_prefill = !!$(this).prop('checked'); + saveSettingsDebounced(); + }); + $('#seed_openai').on('input', function () { oai_settings.seed = Number($(this).val()); saveSettingsDebounced(); diff --git a/server.js b/server.js index 6124fb1cd..bebcc7bb0 100644 --- a/server.js +++ b/server.js @@ -621,8 +621,13 @@ const setupTasks = async function () { await loadTokenizers(); await statsEndpoint.init(); - const exitProcess = () => { + const cleanupPlugins = await loadPlugins(); + + const exitProcess = async () => { statsEndpoint.onExit(); + if (typeof cleanupPlugins === 'function') { + await cleanupPlugins(); + } process.exit(); }; @@ -634,7 +639,6 @@ const setupTasks = async function () { exitProcess(); }); - await loadPlugins(); console.log('Launching...'); @@ -647,13 +651,19 @@ const setupTasks = async function () { } }; +/** + * Loads server plugins from a directory. + * @returns {Promise} Function to be run on server exit + */ async function loadPlugins() { try { const pluginDirectory = path.join(serverDirectory, 'plugins'); const loader = require('./src/plugin-loader'); - await loader.loadPlugins(app, pluginDirectory); + const cleanupPlugins = await loader.loadPlugins(app, pluginDirectory); + return cleanupPlugins; } catch { console.log('Plugin loading failed.'); + return () => {}; } } diff --git a/src/endpoints/horde.js b/src/endpoints/horde.js index e26b74436..a0d702a62 100644 --- a/src/endpoints/horde.js +++ b/src/endpoints/horde.js @@ -1,20 +1,30 @@ const fetch = require('node-fetch').default; const express = require('express'); const AIHorde = require('../ai_horde'); -const { getVersion, delay } = require('../util'); +const { getVersion, delay, Cache } = require('../util'); const { readSecret, SECRET_KEYS } = require('./secrets'); const { jsonParser } = require('../express-common'); const ANONYMOUS_KEY = '0000000000'; +const cache = new Cache(60 * 1000); +const router = express.Router(); + +/** + * Returns the AIHorde client agent. + * @returns {Promise} AIHorde client agent + */ +async function getClientAgent() { + const version = await getVersion(); + return version?.agent || 'SillyTavern:UNKNOWN:Cohee#1207'; +} /** * Returns the AIHorde client. * @returns {Promise} AIHorde client */ async function getHordeClient() { - const version = await getVersion(); const ai_horde = new AIHorde({ - client_agent: version?.agent || 'SillyTavern:UNKNOWN:Cohee#1207', + client_agent: await getClientAgent(), }); return ai_horde; } @@ -46,11 +56,112 @@ function sanitizeHordeImagePrompt(prompt) { return prompt; } -const router = express.Router(); +router.post('/text-workers', jsonParser, async (request, response) => { + try { + const cachedWorkers = cache.get('workers'); + + if (cachedWorkers && !request.body.force) { + return response.send(cachedWorkers); + } + + const agent = await getClientAgent(); + const fetchResult = await fetch('https://horde.koboldai.net/api/v2/workers?type=text', { + headers: { + 'Client-Agent': agent, + }, + }); + const data = await fetchResult.json(); + cache.set('workers', data); + return response.send(data); + } catch (error) { + console.error(error); + response.sendStatus(500); + } +}); + +router.post('/text-models', jsonParser, async (request, response) => { + try { + const cachedModels = cache.get('models'); + + if (cachedModels && !request.body.force) { + return response.send(cachedModels); + } + + const agent = await getClientAgent(); + const fetchResult = await fetch('https://horde.koboldai.net/api/v2/status/models?type=text', { + headers: { + 'Client-Agent': agent, + }, + }); + + const data = await fetchResult.json(); + cache.set('models', data); + return response.send(data); + } catch (error) { + console.error(error); + response.sendStatus(500); + } +}); + +router.post('/status', jsonParser, async (_, response) => { + try { + const agent = await getClientAgent(); + const fetchResult = await fetch('https://horde.koboldai.net/api/v2/status/heartbeat', { + headers: { + 'Client-Agent': agent, + }, + }); + + return response.send({ ok: fetchResult.ok }); + } catch (error) { + console.error(error); + response.sendStatus(500); + } +}); + +router.post('/cancel-task', jsonParser, async (request, response) => { + try { + const taskId = request.body.taskId; + const agent = await getClientAgent(); + const fetchResult = await fetch(`https://horde.koboldai.net/api/v2/generate/text/status/${taskId}`, { + method: 'DELETE', + headers: { + 'Client-Agent': agent, + }, + }); + + const data = await fetchResult.json(); + console.log(`Cancelled Horde task ${taskId}`); + return response.send(data); + } catch (error) { + console.error(error); + response.sendStatus(500); + } +}); + +router.post('/task-status', jsonParser, async (request, response) => { + try { + const taskId = request.body.taskId; + const agent = await getClientAgent(); + const fetchResult = await fetch(`https://horde.koboldai.net/api/v2/generate/text/status/${taskId}`, { + headers: { + 'Client-Agent': agent, + }, + }); + + const data = await fetchResult.json(); + console.log(`Horde task ${taskId} status:`, data); + return response.send(data); + } catch (error) { + console.error(error); + response.sendStatus(500); + } +}); router.post('/generate-text', jsonParser, async (request, response) => { - const api_key_horde = readSecret(SECRET_KEYS.HORDE) || ANONYMOUS_KEY; + const apiKey = readSecret(SECRET_KEYS.HORDE) || ANONYMOUS_KEY; const url = 'https://horde.koboldai.net/api/v2/generate/text/async'; + const agent = await getClientAgent(); console.log(request.body); try { @@ -59,8 +170,8 @@ router.post('/generate-text', jsonParser, async (request, response) => { body: JSON.stringify(request.body), headers: { 'Content-Type': 'application/json', - 'apikey': api_key_horde, - 'Client-Agent': String(request.header('Client-Agent')), + 'apikey': apiKey, + 'Client-Agent': agent, }, }); diff --git a/src/endpoints/thumbnails.js b/src/endpoints/thumbnails.js index 26b585b76..ad898db14 100644 --- a/src/endpoints/thumbnails.js +++ b/src/endpoints/thumbnails.js @@ -111,7 +111,8 @@ async function generateThumbnail(type, file) { try { const quality = getConfigValue('thumbnailsQuality', 95); const image = await jimp.read(pathToOriginalFile); - buffer = await image.cover(mySize[0], mySize[1]).quality(quality).getBufferAsync('image/jpeg'); + const imgType = type == 'avatar' && getConfigValue('avatarThumbnailsPng', false) ? 'image/png' : 'image/jpeg'; + buffer = await image.cover(mySize[0], mySize[1]).quality(quality).getBufferAsync(imgType); } catch (inner) { console.warn(`Thumbnailer can not process the image: ${pathToOriginalFile}. Using original size`); diff --git a/src/endpoints/translate.js b/src/endpoints/translate.js index 9bbe15391..b095940f3 100644 --- a/src/endpoints/translate.js +++ b/src/endpoints/translate.js @@ -106,6 +106,10 @@ router.post('/deepl', jsonParser, async (request, response) => { return response.sendStatus(400); } + if (request.body.lang === 'zh-CN' || request.body.lang === 'zh-TW') { + request.body.lang = 'ZH'; + } + const text = request.body.text; const lang = request.body.lang; const formality = getConfigValue('deepl.formality', 'default'); @@ -221,7 +225,7 @@ router.post('/deeplx', jsonParser, async (request, response) => { const text = request.body.text; let lang = request.body.lang; - if (request.body.lang === 'zh-CN') { + if (request.body.lang === 'zh-CN' || request.body.lang === 'zh-TW') { lang = 'ZH'; } diff --git a/src/plugin-loader.js b/src/plugin-loader.js index ba19fa603..92e566a51 100644 --- a/src/plugin-loader.js +++ b/src/plugin-loader.js @@ -1,8 +1,16 @@ const fs = require('fs'); const path = require('path'); +const url = require('url'); +const express = require('express'); const { getConfigValue } = require('./util'); const enableServerPlugins = getConfigValue('enableServerPlugins', false); +/** + * Map of loaded plugins. + * @type {Map} + */ +const loadedPlugins = new Map(); + /** * Determine if a file is a CommonJS module. * @param {string} file Path to file @@ -21,31 +29,35 @@ const isESModule = (file) => path.extname(file) === '.mjs'; * Load and initialize server plugins from a directory if they are enabled. * @param {import('express').Express} app Express app * @param {string} pluginsPath Path to plugins directory - * @returns {Promise} Promise that resolves when all plugins are loaded + * @returns {Promise} Promise that resolves when all plugins are loaded. Resolves to a "cleanup" function to + * be called before the server shuts down. */ async function loadPlugins(app, pluginsPath) { + const exitHooks = []; + const emptyFn = () => {}; + // Server plugins are disabled. if (!enableServerPlugins) { - return; + return emptyFn; } // Plugins directory does not exist. if (!fs.existsSync(pluginsPath)) { - return; + return emptyFn; } const files = fs.readdirSync(pluginsPath); // No plugins to load. if (files.length === 0) { - return; + return emptyFn; } for (const file of files) { const pluginFilePath = path.join(pluginsPath, file); if (fs.statSync(pluginFilePath).isDirectory()) { - await loadFromDirectory(app, pluginFilePath); + await loadFromDirectory(app, pluginFilePath, exitHooks); continue; } @@ -54,11 +66,14 @@ async function loadPlugins(app, pluginsPath) { continue; } - await loadFromFile(app, pluginFilePath); + await loadFromFile(app, pluginFilePath, exitHooks); } + + // Call all plugin "exit" functions at once and wait for them to finish + return () => Promise.all(exitHooks.map(exitFn => exitFn())); } -async function loadFromDirectory(app, pluginDirectoryPath) { +async function loadFromDirectory(app, pluginDirectoryPath, exitHooks) { const files = fs.readdirSync(pluginDirectoryPath); // No plugins to load. @@ -69,7 +84,7 @@ async function loadFromDirectory(app, pluginDirectoryPath) { // Plugin is an npm package. const packageJsonFilePath = path.join(pluginDirectoryPath, 'package.json'); if (fs.existsSync(packageJsonFilePath)) { - if (await loadFromPackage(app, packageJsonFilePath)) { + if (await loadFromPackage(app, packageJsonFilePath, exitHooks)) { return; } } @@ -77,7 +92,7 @@ async function loadFromDirectory(app, pluginDirectoryPath) { // Plugin is a CommonJS module. const cjsFilePath = path.join(pluginDirectoryPath, 'index.js'); if (fs.existsSync(cjsFilePath)) { - if (await loadFromFile(app, cjsFilePath)) { + if (await loadFromFile(app, cjsFilePath, exitHooks)) { return; } } @@ -85,7 +100,7 @@ async function loadFromDirectory(app, pluginDirectoryPath) { // Plugin is an ECMAScript module. const esmFilePath = path.join(pluginDirectoryPath, 'index.mjs'); if (fs.existsSync(esmFilePath)) { - if (await loadFromFile(app, esmFilePath)) { + if (await loadFromFile(app, esmFilePath, exitHooks)) { return; } } @@ -95,14 +110,16 @@ async function loadFromDirectory(app, pluginDirectoryPath) { * Loads and initializes a plugin from an npm package. * @param {import('express').Express} app Express app * @param {string} packageJsonPath Path to package.json file + * @param {Array} exitHooks Array of functions to be run on plugin exit. Will be pushed to if the plugin has + * an "exit" function. * @returns {Promise} Promise that resolves to true if plugin was loaded successfully */ -async function loadFromPackage(app, packageJsonPath) { +async function loadFromPackage(app, packageJsonPath, exitHooks) { try { const packageJson = JSON.parse(fs.readFileSync(packageJsonPath, 'utf8')); if (packageJson.main) { const pluginFilePath = path.join(path.dirname(packageJsonPath), packageJson.main); - return await loadFromFile(app, pluginFilePath); + return await loadFromFile(app, pluginFilePath, exitHooks); } } catch (error) { console.error(`Failed to load plugin from ${packageJsonPath}: ${error}`); @@ -114,13 +131,16 @@ async function loadFromPackage(app, packageJsonPath) { * Loads and initializes a plugin from a file. * @param {import('express').Express} app Express app * @param {string} pluginFilePath Path to plugin directory + * @param {Array.} exitHooks Array of functions to be run on plugin exit. Will be pushed to if the plugin has + * an "exit" function. * @returns {Promise} Promise that resolves to true if plugin was loaded successfully */ -async function loadFromFile(app, pluginFilePath) { +async function loadFromFile(app, pluginFilePath, exitHooks) { try { - const plugin = await getPluginModule(pluginFilePath); + const fileUrl = url.pathToFileURL(pluginFilePath).toString(); + const plugin = await import(fileUrl); console.log(`Initializing plugin from ${pluginFilePath}`); - return await initPlugin(app, plugin); + return await initPlugin(app, plugin, exitHooks); } catch (error) { console.error(`Failed to load plugin from ${pluginFilePath}: ${error}`); return false; @@ -128,33 +148,72 @@ async function loadFromFile(app, pluginFilePath) { } /** - * Initializes a plugin module. - * @param {import('express').Express} app Express app - * @param {any} plugin Plugin module - * @returns {Promise} Promise that resolves to true if plugin was initialized successfully + * Check whether a plugin ID is valid (only lowercase alphanumeric, hyphens, and underscores). + * @param {string} id The plugin ID to check + * @returns {boolean} True if the plugin ID is valid. */ -async function initPlugin(app, plugin) { - if (typeof plugin.init === 'function') { - await plugin.init(app); - return true; - } - - return false; +function isValidPluginID(id) { + return /^[a-z0-9_-]+$/.test(id); } /** - * Loads a module from a file depending on the module type. - * @param {string} pluginFilePath Path to plugin file - * @returns {Promise} Promise that resolves to plugin module + * Initializes a plugin module. + * @param {import('express').Express} app Express app + * @param {any} plugin Plugin module + * @param {Array.} exitHooks Array of functions to be run on plugin exit. Will be pushed to if the plugin has + * an "exit" function. + * @returns {Promise} Promise that resolves to true if plugin was initialized successfully */ -async function getPluginModule(pluginFilePath) { - if (isCommonJS(pluginFilePath)) { - return require(pluginFilePath); +async function initPlugin(app, plugin, exitHooks) { + const info = plugin.info || plugin.default?.info; + if (typeof info !== 'object') { + console.error('Failed to load plugin module; plugin info not found'); + return false; } - if (isESModule(pluginFilePath)) { - return await import(pluginFilePath); + + // We don't currently use "name" or "description" but it would be nice to have a UI for listing server plugins, so + // require them now just to be safe + for (const field of ['id', 'name', 'description']) { + if (typeof info[field] !== 'string') { + console.error(`Failed to load plugin module; plugin info missing field '${field}'`); + return false; + } } - throw new Error(`Unsupported module type in ${pluginFilePath}`); + + if (typeof plugin.init !== 'function') { + console.error('Failed to load plugin module; no init function'); + return false; + } + + const { id } = info; + + if (!isValidPluginID(id)) { + console.error(`Failed to load plugin module; invalid plugin ID '${id}'`); + return false; + } + + if (loadedPlugins.has(id)) { + console.error(`Failed to load plugin module; plugin ID '${id}' is already in use`); + return false; + } + + // Allow the plugin to register API routes under /api/plugins/[plugin ID] via a router + const router = express.Router(); + + await plugin.init(router); + + loadedPlugins.set(id, plugin); + + // Add API routes to the app if the plugin registered any + if (router.stack.length > 0) { + app.use(`/api/plugins/${id}`, router); + } + + if (typeof plugin.exit === 'function') { + exitHooks.push(plugin.exit); + } + + return true; } module.exports = { diff --git a/src/util.js b/src/util.js index ef5da97ff..1d437379e 100644 --- a/src/util.js +++ b/src/util.js @@ -467,6 +467,61 @@ function trimV1(str) { return String(str ?? '').replace(/\/$/, '').replace(/\/v1$/, ''); } +/** + * Simple TTL memory cache. + */ +class Cache { + /** + * @param {number} ttl Time to live in milliseconds + */ + constructor(ttl) { + this.cache = new Map(); + this.ttl = ttl; + } + + /** + * Gets a value from the cache. + * @param {string} key Cache key + */ + get(key) { + const value = this.cache.get(key); + if (value?.expiry > Date.now()) { + return value.value; + } + + // Cache miss or expired, remove the key + this.cache.delete(key); + return null; + } + + /** + * Sets a value in the cache. + * @param {string} key Key + * @param {object} value Value + */ + set(key, value) { + this.cache.set(key, { + value: value, + expiry: Date.now() + this.ttl, + }); + } + + /** + * Removes a value from the cache. + * @param {string} key Key + */ + remove(key) { + this.cache.delete(key); + } + + /** + * Clears the cache. + */ + clear() { + this.cache.clear(); + } +} + module.exports = { getConfig, getConfigValue, @@ -491,4 +546,5 @@ module.exports = { mergeObjectWithYaml, excludeKeysByYaml, trimV1, + Cache, };