From 66ec17620f976fe2fc3e78a89adeca6c73d98f26 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Tue, 12 Sep 2023 20:45:36 +0300 Subject: [PATCH] Move Horde and SD endpoints into separate files --- .../extensions/stable-diffusion/index.js | 6 +- public/scripts/horde.js | 6 +- server.js | 420 +----------------- src/{horde => ai_horde}/LICENSE.md | 0 src/{horde => ai_horde}/index.d.ts | 0 src/{horde => ai_horde}/index.js | 0 src/{horde => ai_horde}/index.mjs | 0 src/horde.js | 169 +++++++ src/stable-diffusion.js | 247 ++++++++++ src/util.js | 67 ++- 10 files changed, 492 insertions(+), 423 deletions(-) rename src/{horde => ai_horde}/LICENSE.md (100%) rename src/{horde => ai_horde}/index.d.ts (100%) rename src/{horde => ai_horde}/index.js (100%) rename src/{horde => ai_horde}/index.mjs (100%) create mode 100644 src/horde.js create mode 100644 src/stable-diffusion.js diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index 3168a20b7..886cb8141 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -622,7 +622,7 @@ async function loadSamplers() { } async function loadHordeSamplers() { - const result = await fetch('/horde_samplers', { + const result = await fetch('/api/horde/sd-samplers', { method: 'POST', headers: getRequestHeaders(), }); @@ -721,7 +721,7 @@ async function loadModels() { } async function loadHordeModels() { - const result = await fetch('/horde_models', { + const result = await fetch('/api/horde/sd-models', { method: 'POST', headers: getRequestHeaders(), }); @@ -1084,7 +1084,7 @@ async function generateExtrasImage(prompt) { * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. */ async function generateHordeImage(prompt) { - const result = await fetch('/horde_generateimage', { + const result = await fetch('/api/horde/generate-image', { method: 'POST', headers: getRequestHeaders(), body: JSON.stringify({ diff --git a/public/scripts/horde.js b/public/scripts/horde.js index 1c66e0621..86a791993 100644 --- a/public/scripts/horde.js +++ b/public/scripts/horde.js @@ -107,7 +107,7 @@ async function generateHorde(prompt, params, signal) { "models": horde_settings.models, }; - const response = await fetch("/generate_horde", { + const response = await fetch("/api/horde/generate-text", { method: 'POST', headers: { ...getRequestHeaders(), @@ -210,7 +210,7 @@ function loadHordeSettings(settings) { } async function showKudos() { - const response = await fetch('/horde_userinfo', { + const response = await fetch('/api/horde/user-info', { method: 'POST', headers: getRequestHeaders(), }); @@ -256,7 +256,7 @@ jQuery(function () { }) $("#horde_api_key").on("input", async function () { - const key = $(this).val().trim(); + const key = String($(this).val()).trim(); await writeSecret(SECRET_KEYS.HORDE, key); }); diff --git a/server.js b/server.js index 80c47884e..d0cb96d0e 100644 --- a/server.js +++ b/server.js @@ -1,7 +1,6 @@ #!/usr/bin/env node // native node modules -const child_process = require('child_process') const crypto = require('crypto'); const fs = require('fs'); const http = require("http"); @@ -14,7 +13,6 @@ const { finished } = require('stream/promises'); const { TextEncoder, TextDecoder } = require('util'); // cli/fs related library imports -const commandExistsSync = require('command-exists').sync; const open = require('open'); const sanitize = require('sanitize-filename'); const simpleGit = require('simple-git'); @@ -67,13 +65,13 @@ util.inspect.defaultOptions.maxStringLength = null; createDefaultFiles(); // local library imports -const AIHorde = require("./src/horde"); const basicAuthMiddleware = require('./src/middleware/basicAuthMiddleware'); const characterCardParser = require('./src/character-card-parser.js'); const contentManager = require('./src/content-manager'); const novelai = require('./src/novelai'); const statsHelpers = require('./statsHelpers.js'); const { writeSecret, readSecret, readSecretState, migrateSecrets, SECRET_KEYS, getAllSecrets } = require('./src/secrets'); +const { delay, getVersion } = require('./src/util'); function createDefaultFiles() { const files = { @@ -158,13 +156,6 @@ const enableExtensions = config.enableExtensions; const listen = config.listen; const allowKeysExposure = config.allowKeysExposure; -function getHordeClient() { - const ai_horde = new AIHorde({ - client_agent: getVersion()?.agent || 'SillyTavern:UNKNOWN:Cohee#1207', - }); - return ai_horde; -} - const API_NOVELAI = "https://api.novelai.net"; const API_OPENAI = "https://api.openai.com/v1"; const API_CLAUDE = "https://api.anthropic.com/v1"; @@ -206,16 +197,6 @@ function getOverrideHeaders(urlHost) { } } -/** - * Encodes the Basic Auth header value for the given user and password. - * @param {string} auth username:password - * @returns {string} Basic Auth header value - */ -function getBasicAuthHeader(auth) { - const encoded = Buffer.from(`${auth}`).toString('base64'); - return `Basic ${encoded}`; -} - //RossAscends: Added function to format dates used in files and chat timestamps to a humanized format. //Mostly I wanted this to be for file names, but couldn't figure out exactly where the filename save code was as everything seemed to be connected. //During testing, this performs the same as previous date.now() structure. @@ -223,8 +204,6 @@ function getBasicAuthHeader(auth) { //New chats made with characters will use this new formatting. //Useable variable is (( humanizedISO8601Datetime )) -const delay = ms => new Promise(resolve => setTimeout(resolve, ms)) - const CHARS_PER_TOKEN = 3.35; let spp_llama; @@ -911,30 +890,6 @@ app.post("/getstatus", jsonParser, async function (request, response) { } }); -function getVersion() { - let pkgVersion = 'UNKNOWN'; - let gitRevision = null; - let gitBranch = null; - try { - const pkgJson = require('./package.json'); - pkgVersion = pkgJson.version; - if (!process['pkg'] && commandExistsSync('git')) { - gitRevision = child_process - .execSync('git rev-parse --short HEAD', { cwd: process.cwd(), stdio: ['ignore', 'pipe', 'ignore'] }) - .toString().trim(); - - gitBranch = child_process - .execSync('git rev-parse --abbrev-ref HEAD', { cwd: process.cwd(), stdio: ['ignore', 'pipe', 'ignore'] }) - .toString().trim(); - } - } - catch { - // suppress exception - } - - const agent = `SillyTavern:${pkgVersion}:Cohee#1207`; - return { agent, pkgVersion, gitRevision, gitBranch }; -} function tryParse(str) { try { @@ -4283,38 +4238,6 @@ app.post('/readsecretstate', jsonParser, (_, response) => { } }); -const ANONYMOUS_KEY = "0000000000"; - -app.post('/generate_horde', jsonParser, async (request, response) => { - const api_key_horde = readSecret(SECRET_KEYS.HORDE) || ANONYMOUS_KEY; - const url = 'https://horde.koboldai.net/api/v2/generate/text/async'; - - const args = { - "body": JSON.stringify(request.body), - "headers": { - "Content-Type": "application/json", - "apikey": api_key_horde, - } - }; - if (request.header('Client-Agent') !== undefined) args.headers['Client-Agent'] = request.header('Client-Agent'); - - console.log(request.body); - try { - const data = await postAsync(url, args); - return response.send(data); - } catch (error) { - console.log('Horde returned an error:', error.statusText); - - if (typeof error.text === 'function') { - const message = await error.text(); - console.log(message); - return response.send({ error: { message } }); - } else { - return response.send({ error: true }); - } - } -}); - app.post('/viewsecrets', jsonParser, async (_, response) => { if (!allowKeysExposure) { console.error('secrets.json could not be viewed unless the value of allowKeysExposure in config.conf is set to true'); @@ -4335,109 +4258,6 @@ app.post('/viewsecrets', jsonParser, async (_, response) => { } }); -app.post('/horde_samplers', jsonParser, async (_, response) => { - try { - const ai_horde = getHordeClient(); - const samplers = Object.values(ai_horde.ModelGenerationInputStableSamplers); - response.send(samplers); - } catch (error) { - console.error(error); - response.sendStatus(500); - } -}); - -app.post('/horde_models', jsonParser, async (_, response) => { - try { - const ai_horde = getHordeClient(); - const models = await ai_horde.getModels(); - response.send(models); - } catch (error) { - console.error(error); - response.sendStatus(500); - } -}); - -app.post('/horde_userinfo', jsonParser, async (_, response) => { - const api_key_horde = readSecret(SECRET_KEYS.HORDE); - - if (!api_key_horde) { - return response.send({ anonymous: true }); - } - - try { - const ai_horde = getHordeClient(); - const user = await ai_horde.findUser({ token: api_key_horde }); - return response.send(user); - } catch (error) { - console.error(error); - return response.sendStatus(500); - } -}) - -app.post('/horde_generateimage', jsonParser, async (request, response) => { - const MAX_ATTEMPTS = 200; - const CHECK_INTERVAL = 3000; - const api_key_horde = readSecret(SECRET_KEYS.HORDE) || ANONYMOUS_KEY; - console.log('Stable Horde request:', request.body); - - try { - const ai_horde = getHordeClient(); - const generation = await ai_horde.postAsyncImageGenerate( - { - prompt: `${request.body.prompt} ### ${request.body.negative_prompt}`, - params: - { - sampler_name: request.body.sampler, - hires_fix: request.body.enable_hr, - // @ts-ignore - use_gfpgan param is not in the type definition, need to update to new ai_horde @ https://github.com/ZeldaFan0225/ai_horde/blob/main/index.ts - use_gfpgan: request.body.restore_faces, - cfg_scale: request.body.scale, - steps: request.body.steps, - width: request.body.width, - height: request.body.height, - karras: Boolean(request.body.karras), - n: 1, - }, - r2: false, - nsfw: request.body.nfsw, - models: [request.body.model], - }, - { token: api_key_horde }); - - if (!generation.id) { - console.error('Image generation request is not satisfyable:', generation.message || 'unknown error'); - return response.sendStatus(400); - } - - for (let attempt = 0; attempt < MAX_ATTEMPTS; attempt++) { - await delay(CHECK_INTERVAL); - const check = await ai_horde.getImageGenerationCheck(generation.id); - console.log(check); - - if (check.done) { - const result = await ai_horde.getImageGenerationStatus(generation.id); - if (result.generations === undefined) return response.sendStatus(500); - return response.send(result.generations[0].img); - } - - /* - if (!check.is_possible) { - return response.sendStatus(503); - } - */ - - if (check.faulted) { - return response.sendStatus(500); - } - } - - return response.sendStatus(504); - } catch (error) { - console.error(error); - return response.sendStatus(500); - } -}); - app.post('/api/novelai/generate-image', jsonParser, async (request, response) => { if (!request.body) { return response.sendStatus(400); @@ -4528,240 +4348,6 @@ app.post('/api/novelai/generate-image', jsonParser, async (request, response) => } }); -app.post('/api/sd/ping', jsonParser, async (request, response) => { - try { - const url = new URL(request.body.url); - url.pathname = '/internal/ping'; - - const result = await fetch(url, { - method: 'GET', - headers: { - 'Authorization': getBasicAuthHeader(request.body.auth), - } - }); - - if (!result.ok) { - throw new Error('SD WebUI returned an error.'); - } - - return response.sendStatus(200); - } catch (error) { - console.log(error); - return response.sendStatus(500); - } -}); - -app.post('/api/sd/upscalers', jsonParser, async (request, response) => { - try { - async function getUpscalerModels() { - const url = new URL(request.body.url); - url.pathname = '/sdapi/v1/upscalers'; - - const result = await fetch(url, { - method: 'GET', - headers: { - 'Authorization': getBasicAuthHeader(request.body.auth), - }, - }); - - if (!result.ok) { - throw new Error('SD WebUI returned an error.'); - } - - const data = await result.json(); - const names = data.map(x => x.name); - return names; - } - - async function getLatentUpscalers() { - const url = new URL(request.body.url); - url.pathname = '/sdapi/v1/latent-upscale-modes'; - - const result = await fetch(url, { - method: 'GET', - headers: { - 'Authorization': getBasicAuthHeader(request.body.auth), - }, - }); - - if (!result.ok) { - throw new Error('SD WebUI returned an error.'); - } - - const data = await result.json(); - const names = data.map(x => x.name); - return names; - } - - const [upscalers, latentUpscalers] = await Promise.all([getUpscalerModels(), getLatentUpscalers()]); - - // 0 = None, then Latent Upscalers, then Upscalers - upscalers.splice(1, 0, ...latentUpscalers); - - return response.send(upscalers); - } catch (error) { - console.log(error); - return response.sendStatus(500); - } -}); - -app.post('/api/sd/samplers', jsonParser, async (request, response) => { - try { - const url = new URL(request.body.url); - url.pathname = '/sdapi/v1/samplers'; - - const result = await fetch(url, { - method: 'GET', - headers: { - 'Authorization': getBasicAuthHeader(request.body.auth), - }, - }); - - if (!result.ok) { - throw new Error('SD WebUI returned an error.'); - } - - const data = await result.json(); - const names = data.map(x => x.name); - return response.send(names); - - } catch (error) { - console.log(error); - return response.sendStatus(500); - } -}); - -app.post('/api/sd/models', jsonParser, async (request, response) => { - try { - const url = new URL(request.body.url); - url.pathname = '/sdapi/v1/sd-models'; - - const result = await fetch(url, { - method: 'GET', - headers: { - 'Authorization': getBasicAuthHeader(request.body.auth), - }, - }); - - if (!result.ok) { - throw new Error('SD WebUI returned an error.'); - } - - const data = await result.json(); - const models = data.map(x => ({ value: x.title, text: x.title })); - return response.send(models); - } catch (error) { - console.log(error); - return response.sendStatus(500); - } -}); - -app.post('/api/sd/get-model', jsonParser, async (request, response) => { - try { - const url = new URL(request.body.url); - url.pathname = '/sdapi/v1/options'; - - const result = await fetch(url, { - method: 'GET', - headers: { - 'Authorization': getBasicAuthHeader(request.body.auth), - }, - }); - const data = await result.json(); - return response.send(data['sd_model_checkpoint']); - } catch (error) { - console.log(error); - return response.sendStatus(500); - } -}); - -app.post('/api/sd/set-model', jsonParser, async (request, response) => { - try { - async function getProgress() { - const url = new URL(request.body.url); - url.pathname = '/sdapi/v1/progress'; - - const result = await fetch(url, { - method: 'GET', - headers: { - 'Authorization': getBasicAuthHeader(request.body.auth), - }, - }); - const data = await result.json(); - return data; - } - - const url = new URL(request.body.url); - url.pathname = '/sdapi/v1/options'; - - const options = { - sd_model_checkpoint: request.body.model, - }; - - const result = await fetch(url, { - method: 'POST', - body: JSON.stringify(options), - headers: { - 'Content-Type': 'application/json', - 'Authorization': getBasicAuthHeader(request.body.auth), - }, - }); - - if (!result.ok) { - throw new Error('SD WebUI returned an error.'); - } - - const MAX_ATTEMPTS = 10; - const CHECK_INTERVAL = 2000; - - for (let attempt = 0; attempt < MAX_ATTEMPTS; attempt++) { - const progressState = await getProgress(); - - const progress = progressState["progress"] - const jobCount = progressState["state"]["job_count"]; - if (progress == 0.0 && jobCount === 0) { - break; - } - - console.log(`Waiting for SD WebUI to finish model loading... Progress: ${progress}; Job count: ${jobCount}`); - await delay(CHECK_INTERVAL); - } - - return response.sendStatus(200); - } catch (error) { - console.log(error); - return response.sendStatus(500); - } -}); - -app.post('/api/sd/generate', jsonParser, async (request, response) => { - try { - console.log('SD WebUI request:', request.body); - - const url = new URL(request.body.url); - url.pathname = '/sdapi/v1/txt2img'; - - const result = await fetch(url, { - method: 'POST', - body: JSON.stringify(request.body), - headers: { - 'Content-Type': 'application/json', - 'Authorization': getBasicAuthHeader(request.body.auth), - }, - }); - - if (!result.ok) { - throw new Error('SD WebUI returned an error.'); - } - - const data = await result.json(); - return response.send(data); - } catch (error) { - console.log(error); - return response.sendStatus(500); - } -}); - app.post('/novel_tts', jsonParser, async (request, response) => { const token = readSecret(SECRET_KEYS.NOVEL); @@ -5656,6 +5242,10 @@ app.post('/get_character_assets_list', jsonParser, async (request, response) => } }); +// Stable Diffusion generation +require('./src/stable-diffusion').registerEndpoints(app, jsonParser); +// LLM and SD Horde generation +require('./src/horde').registerEndpoints(app, jsonParser); // Vector storage DB require('./src/vectors').registerEndpoints(app, jsonParser); // Chat translation diff --git a/src/horde/LICENSE.md b/src/ai_horde/LICENSE.md similarity index 100% rename from src/horde/LICENSE.md rename to src/ai_horde/LICENSE.md diff --git a/src/horde/index.d.ts b/src/ai_horde/index.d.ts similarity index 100% rename from src/horde/index.d.ts rename to src/ai_horde/index.d.ts diff --git a/src/horde/index.js b/src/ai_horde/index.js similarity index 100% rename from src/horde/index.js rename to src/ai_horde/index.js diff --git a/src/horde/index.mjs b/src/ai_horde/index.mjs similarity index 100% rename from src/horde/index.mjs rename to src/ai_horde/index.mjs diff --git a/src/horde.js b/src/horde.js new file mode 100644 index 000000000..43830f8fc --- /dev/null +++ b/src/horde.js @@ -0,0 +1,169 @@ +const fetch = require('node-fetch').default; +const AIHorde = require("./ai_horde"); +const { getVersion, delay } = require("./util"); +const { readSecret, SECRET_KEYS } = require("./secrets"); + +const ANONYMOUS_KEY = "0000000000"; + +function getHordeClient() { + const ai_horde = new AIHorde({ + client_agent: getVersion()?.agent || 'SillyTavern:UNKNOWN:Cohee#1207', + }); + return ai_horde; +} + +/** + * + * @param {import("express").Express} app + * @param {any} jsonParser + */ +function registerEndpoints(app, jsonParser) { + app.post('/api/horde/generate-text', jsonParser, async (request, response) => { + const api_key_horde = readSecret(SECRET_KEYS.HORDE) || ANONYMOUS_KEY; + const url = 'https://horde.koboldai.net/api/v2/generate/text/async'; + + console.log(request.body); + try { + const result = await fetch(url, { + method: 'POST', + body: JSON.stringify(request.body), + headers: { + "Content-Type": "application/json", + "apikey": api_key_horde, + 'Client-Agent': String(request.header('Client-Agent')), + } + }); + + if (!result.ok) { + const message = await result.text(); + console.log('Horde returned an error:', message); + return response.send({ error: { message } }); + } + + const data = await result.json(); + return response.send(data); + } catch (error) { + console.log(error); + return response.send({ error: true }); + } + }); + + app.post('/api/horde/sd-samplers', jsonParser, async (_, response) => { + try { + const ai_horde = getHordeClient(); + const samplers = Object.values(ai_horde.ModelGenerationInputStableSamplers); + response.send(samplers); + } catch (error) { + console.error(error); + response.sendStatus(500); + } + }); + + app.post('/api/horde/sd-models', jsonParser, async (_, response) => { + try { + const ai_horde = getHordeClient(); + const models = await ai_horde.getModels(); + response.send(models); + } catch (error) { + console.error(error); + response.sendStatus(500); + } + }); + + app.post('/api/horde/user-info', jsonParser, async (_, response) => { + const api_key_horde = readSecret(SECRET_KEYS.HORDE); + + if (!api_key_horde) { + return response.send({ anonymous: true }); + } + + try { + const ai_horde = getHordeClient(); + const user = await ai_horde.findUser({ token: api_key_horde }); + return response.send(user); + } catch (error) { + console.error(error); + return response.sendStatus(500); + } + }) + + app.post('/api/horde/generate-image', jsonParser, async (request, response) => { + if (!request.body.prompt) { + return response.sendStatus(400); + } + + const MAX_ATTEMPTS = 200; + const CHECK_INTERVAL = 3000; + const PROMPT_THRESHOLD = 1000; + + try { + const maxLength = PROMPT_THRESHOLD - String(request.body.negative_prompt).length - 5; + if (String(request.body.prompt).length > maxLength) { + console.log('Stable Horde prompt is too long, truncating...'); + request.body.prompt = String(request.body.prompt).substring(0, maxLength); + } + + const api_key_horde = readSecret(SECRET_KEYS.HORDE) || ANONYMOUS_KEY; + console.log('Stable Horde request:', request.body); + + const ai_horde = getHordeClient(); + const generation = await ai_horde.postAsyncImageGenerate( + { + prompt: `${request.body.prompt} ### ${request.body.negative_prompt}`, + params: + { + sampler_name: request.body.sampler, + hires_fix: request.body.enable_hr, + // @ts-ignore - use_gfpgan param is not in the type definition, need to update to new ai_horde @ https://github.com/ZeldaFan0225/ai_horde/blob/main/index.ts + use_gfpgan: request.body.restore_faces, + cfg_scale: request.body.scale, + steps: request.body.steps, + width: request.body.width, + height: request.body.height, + karras: Boolean(request.body.karras), + n: 1, + }, + r2: false, + nsfw: request.body.nfsw, + models: [request.body.model], + }, + { token: api_key_horde }); + + if (!generation.id) { + console.error('Image generation request is not satisfyable:', generation.message || 'unknown error'); + return response.sendStatus(400); + } + + for (let attempt = 0; attempt < MAX_ATTEMPTS; attempt++) { + await delay(CHECK_INTERVAL); + const check = await ai_horde.getImageGenerationCheck(generation.id); + console.log(check); + + if (check.done) { + const result = await ai_horde.getImageGenerationStatus(generation.id); + if (result.generations === undefined) return response.sendStatus(500); + return response.send(result.generations[0].img); + } + + /* + if (!check.is_possible) { + return response.sendStatus(503); + } + */ + + if (check.faulted) { + return response.sendStatus(500); + } + } + + return response.sendStatus(504); + } catch (error) { + console.error(error); + return response.sendStatus(500); + } + }); +} + +module.exports = { + registerEndpoints, +}; diff --git a/src/stable-diffusion.js b/src/stable-diffusion.js new file mode 100644 index 000000000..438dfd588 --- /dev/null +++ b/src/stable-diffusion.js @@ -0,0 +1,247 @@ +const fetch = require('node-fetch').default; +const { getBasicAuthHeader, delay } = require('./util'); + +/** + * Registers the endpoints for the Stable Diffusion API extension. + * @param {import("express").Express} app Express app + * @param {any} jsonParser JSON parser middleware + */ +function registerEndpoints(app, jsonParser) { + app.post('/api/sd/ping', jsonParser, async (request, response) => { + try { + const url = new URL(request.body.url); + url.pathname = '/internal/ping'; + + const result = await fetch(url, { + method: 'GET', + headers: { + 'Authorization': getBasicAuthHeader(request.body.auth), + }, + }); + + if (!result.ok) { + throw new Error('SD WebUI returned an error.'); + } + + return response.sendStatus(200); + } catch (error) { + console.log(error); + return response.sendStatus(500); + } + }); + + app.post('/api/sd/upscalers', jsonParser, async (request, response) => { + try { + async function getUpscalerModels() { + const url = new URL(request.body.url); + url.pathname = '/sdapi/v1/upscalers'; + + const result = await fetch(url, { + method: 'GET', + headers: { + 'Authorization': getBasicAuthHeader(request.body.auth), + }, + }); + + if (!result.ok) { + throw new Error('SD WebUI returned an error.'); + } + + const data = await result.json(); + const names = data.map(x => x.name); + return names; + } + + async function getLatentUpscalers() { + const url = new URL(request.body.url); + url.pathname = '/sdapi/v1/latent-upscale-modes'; + + const result = await fetch(url, { + method: 'GET', + headers: { + 'Authorization': getBasicAuthHeader(request.body.auth), + }, + }); + + if (!result.ok) { + throw new Error('SD WebUI returned an error.'); + } + + const data = await result.json(); + const names = data.map(x => x.name); + return names; + } + + const [upscalers, latentUpscalers] = await Promise.all([getUpscalerModels(), getLatentUpscalers()]); + + // 0 = None, then Latent Upscalers, then Upscalers + upscalers.splice(1, 0, ...latentUpscalers); + + return response.send(upscalers); + } catch (error) { + console.log(error); + return response.sendStatus(500); + } + }); + + app.post('/api/sd/samplers', jsonParser, async (request, response) => { + try { + const url = new URL(request.body.url); + url.pathname = '/sdapi/v1/samplers'; + + const result = await fetch(url, { + method: 'GET', + headers: { + 'Authorization': getBasicAuthHeader(request.body.auth), + }, + }); + + if (!result.ok) { + throw new Error('SD WebUI returned an error.'); + } + + const data = await result.json(); + const names = data.map(x => x.name); + return response.send(names); + + } catch (error) { + console.log(error); + return response.sendStatus(500); + } + }); + + app.post('/api/sd/models', jsonParser, async (request, response) => { + try { + const url = new URL(request.body.url); + url.pathname = '/sdapi/v1/sd-models'; + + const result = await fetch(url, { + method: 'GET', + headers: { + 'Authorization': getBasicAuthHeader(request.body.auth), + }, + }); + + if (!result.ok) { + throw new Error('SD WebUI returned an error.'); + } + + const data = await result.json(); + const models = data.map(x => ({ value: x.title, text: x.title })); + return response.send(models); + } catch (error) { + console.log(error); + return response.sendStatus(500); + } + }); + + app.post('/api/sd/get-model', jsonParser, async (request, response) => { + try { + const url = new URL(request.body.url); + url.pathname = '/sdapi/v1/options'; + + const result = await fetch(url, { + method: 'GET', + headers: { + 'Authorization': getBasicAuthHeader(request.body.auth), + }, + }); + const data = await result.json(); + return response.send(data['sd_model_checkpoint']); + } catch (error) { + console.log(error); + return response.sendStatus(500); + } + }); + + app.post('/api/sd/set-model', jsonParser, async (request, response) => { + try { + async function getProgress() { + const url = new URL(request.body.url); + url.pathname = '/sdapi/v1/progress'; + + const result = await fetch(url, { + method: 'GET', + headers: { + 'Authorization': getBasicAuthHeader(request.body.auth), + }, + }); + const data = await result.json(); + return data; + } + + const url = new URL(request.body.url); + url.pathname = '/sdapi/v1/options'; + + const options = { + sd_model_checkpoint: request.body.model, + }; + + const result = await fetch(url, { + method: 'POST', + body: JSON.stringify(options), + headers: { + 'Content-Type': 'application/json', + 'Authorization': getBasicAuthHeader(request.body.auth), + }, + }); + + if (!result.ok) { + throw new Error('SD WebUI returned an error.'); + } + + const MAX_ATTEMPTS = 10; + const CHECK_INTERVAL = 2000; + + for (let attempt = 0; attempt < MAX_ATTEMPTS; attempt++) { + const progressState = await getProgress(); + + const progress = progressState["progress"] + const jobCount = progressState["state"]["job_count"]; + if (progress == 0.0 && jobCount === 0) { + break; + } + + console.log(`Waiting for SD WebUI to finish model loading... Progress: ${progress}; Job count: ${jobCount}`); + await delay(CHECK_INTERVAL); + } + + return response.sendStatus(200); + } catch (error) { + console.log(error); + return response.sendStatus(500); + } + }); + + app.post('/api/sd/generate', jsonParser, async (request, response) => { + try { + console.log('SD WebUI request:', request.body); + + const url = new URL(request.body.url); + url.pathname = '/sdapi/v1/txt2img'; + + const result = await fetch(url, { + method: 'POST', + body: JSON.stringify(request.body), + headers: { + 'Content-Type': 'application/json', + 'Authorization': getBasicAuthHeader(request.body.auth), + }, + }); + + if (!result.ok) { + throw new Error('SD WebUI returned an error.'); + } + + const data = await result.json(); + return response.send(data); + } catch (error) { + console.log(error); + return response.sendStatus(500); + } + }); +} + +module.exports = { + registerEndpoints, +}; diff --git a/src/util.js b/src/util.js index a870eb169..3df16a896 100644 --- a/src/util.js +++ b/src/util.js @@ -1,10 +1,73 @@ const path = require('path'); +const child_process = require('child_process'); +const commandExistsSync = require('command-exists').sync; +/** + * Returns the config object from the config.conf file. + * @returns {object} Config object + */ function getConfig() { - const config = require(path.join(process.cwd(), './config.conf')); - return config; + try { + const config = require(path.join(process.cwd(), './config.conf')); + return config; + } catch (error) { + console.warn('Failed to read config.conf'); + return {}; + } +} + +/** + * Encodes the Basic Auth header value for the given user and password. + * @param {string} auth username:password + * @returns {string} Basic Auth header value + */ +function getBasicAuthHeader(auth) { + const encoded = Buffer.from(`${auth}`).toString('base64'); + return `Basic ${encoded}`; +} + +/** + * Returns the version of the running instance. Get the version from the package.json file and the git revision. + * Also returns the agent string for the Horde API. + * @returns {{agent: string, pkgVersion: string, gitRevision: string | null, gitBranch: string | null}} Version info object + */ +function getVersion() { + let pkgVersion = 'UNKNOWN'; + let gitRevision = null; + let gitBranch = null; + try { + const pkgJson = require(path.join(process.cwd(), './package.json')); + pkgVersion = pkgJson.version; + if (!process['pkg'] && commandExistsSync('git')) { + gitRevision = child_process + .execSync('git rev-parse --short HEAD', { cwd: process.cwd(), stdio: ['ignore', 'pipe', 'ignore'] }) + .toString().trim(); + + gitBranch = child_process + .execSync('git rev-parse --abbrev-ref HEAD', { cwd: process.cwd(), stdio: ['ignore', 'pipe', 'ignore'] }) + .toString().trim(); + } + } + catch { + // suppress exception + } + + const agent = `SillyTavern:${pkgVersion}:Cohee#1207`; + return { agent, pkgVersion, gitRevision, gitBranch }; +} + +/** + * Delays the current async function by the given amount of milliseconds. + * @param {number} ms Milliseconds to wait + * @returns {Promise} Promise that resolves after the given amount of milliseconds + */ +function delay(ms) { + return new Promise(resolve => setTimeout(resolve, ms)); } module.exports = { getConfig, + getVersion, + getBasicAuthHeader, + delay, };