Move Stability generation to backend

This commit is contained in:
Cohee 2024-07-04 22:36:17 +03:00
parent 6608e530c5
commit e32b0cc223
5 changed files with 198 additions and 161 deletions

View File

@ -22,7 +22,7 @@ import { getApiUrl, getContext, extension_settings, doExtrasFetch, modules, rend
import { selected_group } from '../../group-chats.js';
import { stringFormat, initScrollHeight, resetScrollHeight, getCharaFilename, saveBase64AsFile, getBase64Async, delay, isTrueBoolean, debounce } from '../../utils.js';
import { getMessageTimeStamp, humanizedDateTime } from '../../RossAscends-mods.js';
import { SECRET_KEYS, secret_state } from '../../secrets.js';
import { SECRET_KEYS, secret_state, writeSecret } from '../../secrets.js';
import { getNovelUnlimitedImageGeneration, getNovelAnlas, loadNovelSubscriptionData } from '../../nai-settings.js';
import { getMultimodalCaption } from '../shared.js';
import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js';
@ -285,11 +285,7 @@ const defaultSettings = {
interactive_visible: false,
// Stability AI settings
stability_api_key: '',
stability_engine: 'V2beta Image Generation',
stability_style_preset: "anime",
stability_aspect_ratio: '1:1',
stability_output_format: 'png',
stability_style_preset: 'anime',
};
const writePromptFieldsDebounced = debounce(writePromptFields, debounce_timeout.relaxed);
@ -452,11 +448,7 @@ async function loadSettings() {
$('#sd_wand_visible').prop('checked', extension_settings.sd.wand_visible);
$('#sd_command_visible').prop('checked', extension_settings.sd.command_visible);
$('#sd_interactive_visible').prop('checked', extension_settings.sd.interactive_visible);
$('#sd_stability_key').val(extension_settings.sd.stability_key);
$('#sd_stability_engine').val(extension_settings.sd.stability_engine);
$('#sd_stability_style_preset').val(extension_settings.sd.stability_style_preset);
$('#sd_stability_aspect_ratio').val(extension_settings.sd.stability_aspect_ratio);
$('#sd_stability_output_format').val(extension_settings.sd.stability_output_format);
for (const style of extension_settings.sd.styles) {
const option = document.createElement('option');
@ -684,7 +676,7 @@ async function refinePrompt(prompt, allowExpand, isNegative = false) {
const refinedPrompt = await callGenericPopup(text + 'Press "Cancel" to abort the image generation.', POPUP_TYPE.INPUT, prompt.trim(), { rows: 5, okButton: 'Continue' });
if (refinedPrompt) {
return refinedPrompt;
return String(refinedPrompt);
} else {
throw new Error('Generation aborted by user.');
}
@ -1097,14 +1089,19 @@ function onComfyWorkflowChange() {
extension_settings.sd.comfy_workflow = $('#sd_comfy_workflow').find(':selected').val();
saveSettingsDebounced();
}
function onStabilityKeyInput() {
extension_settings.sd.stability_key = $('#sd_stability_key').val();
saveSettingsDebounced();
}
function onStabilityEngineChange() {
extension_settings.sd.stability_engine = $('#sd_stability_engine').val();
saveSettingsDebounced();
async function onStabilityKeyClick() {
const popupText = 'Stability AI API Key:';
const key = await callGenericPopup(popupText, POPUP_TYPE.INPUT);
if (!key) {
return;
}
await writeSecret(SECRET_KEYS.STABILITY, String(key));
toastr.success('API Key saved');
await loadSettingOptions();
}
function onStabilityStylePresetChange() {
@ -1112,17 +1109,6 @@ function onStabilityStylePresetChange() {
saveSettingsDebounced();
}
function onStabilityAspectRatioChange() {
extension_settings.sd.stability_aspect_ratio = $('#sd_stability_aspect_ratio').val();
saveSettingsDebounced();
}
function onStabilityOutputFormatChange() {
extension_settings.sd.stability_output_format = $('#sd_stability_output_format').val();
saveSettingsDebounced();
}
async function changeComfyWorkflow(_, name) {
name = name.replace(/(\.json)?$/i, '.json');
if ($(`#sd_comfy_workflow > [value="${name}"]`).length > 0) {
@ -1441,6 +1427,9 @@ async function loadSamplers() {
case sources.pollinations:
samplers = ['N/A'];
break;
case sources.stability:
samplers = ['N/A'];
break;
}
for (const sampler of samplers) {
@ -1643,80 +1632,9 @@ async function loadModels() {
}
}
async function generateStabilityImage(prompt, negativePrompt) {
const payload = {
prompt: prompt,
negative_prompt: negativePrompt,
width: extension_settings.sd.width,
height: extension_settings.sd.height,
seed: extension_settings.sd.seed >= 0 ? extension_settings.sd.seed : undefined,
style_preset: extension_settings.sd.stability_style_preset,
output_format: extension_settings.sd.stability_output_format,
};
const formData = new FormData();
for (const [key, value] of Object.entries(payload)) {
if (value !== undefined) {
formData.append(key, String(value));
}
}
let apiUrl;
switch (extension_settings.sd.model) {
case 'stable-image-ultra':
apiUrl = 'https://api.stability.ai/v2beta/stable-image/generate/ultra';
break;
case 'stable-image-core':
apiUrl = 'https://api.stability.ai/v2beta/stable-image/generate/core';
break;
case 'stable-diffusion-3':
apiUrl = 'https://api.stability.ai/v2beta/stable-image/generate/sd3';
break;
default:
throw new Error('Invalid Stability AI model selected');
}
try {
const response = await fetch(apiUrl, {
method: 'POST',
headers: {
'Authorization': `Bearer ${extension_settings.sd.stability_key}`,
'Accept': 'image/*',
},
body: formData,
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`HTTP ${response.status}: ${errorText}`);
}
const arrayBuffer = await response.arrayBuffer();
const base64Image = arrayBufferToBase64(arrayBuffer);
return {
format: extension_settings.sd.stability_output_format,
data: base64Image,
};
} catch (error) {
console.error('Error generating image with Stability AI:', error);
throw error;
}
}
function arrayBufferToBase64(buffer) {
let binary = '';
const bytes = new Uint8Array(buffer);
const len = bytes.byteLength;
for (let i = 0; i < len; i++) {
binary += String.fromCharCode(bytes[i]);
}
return btoa(binary);
}
async function loadStabilityModels() {
$('#sd_stability_key').toggleClass('success', !!secret_state[SECRET_KEYS.STABILITY]);
return [
{ value: 'stable-image-ultra', text: 'Stable Image Ultra' },
{ value: 'stable-image-core', text: 'Stable Image Core' },
@ -2055,6 +1973,9 @@ async function loadSchedulers() {
case sources.comfy:
schedulers = await loadComfySchedulers();
break;
case sources.stability:
schedulers = ['N/A'];
break;
}
for (const scheduler of schedulers) {
@ -2128,6 +2049,9 @@ async function loadVaes() {
case sources.comfy:
vaes = await loadComfyVaes();
break;
case sources.stability:
vaes = ['N/A'];
break;
}
for (const vae of vaes) {
@ -2611,7 +2535,6 @@ async function sendGenerationRequest(generationType, prompt, additionalNegativeP
case sources.stability:
result = await generateStabilityImage(prefixedPrompt, negativePrompt);
break;
}
if (!result.data) {
@ -2635,6 +2558,12 @@ async function sendGenerationRequest(generationType, prompt, additionalNegativeP
return base64Image;
}
/**
* Generates an image using the TogetherAI API.
* @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/
async function generateTogetherAIImage(prompt, negativePrompt) {
const result = await fetch('/api/sd/together/generate', {
method: 'POST',
@ -2659,6 +2588,12 @@ async function generateTogetherAIImage(prompt, negativePrompt) {
}
}
/**
* Generates an image using the Pollinations API.
* @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/
async function generatePollinationsImage(prompt, negativePrompt) {
const result = await fetch('/api/sd/pollinations/generate', {
method: 'POST',
@ -2727,6 +2662,86 @@ async function generateExtrasImage(prompt, negativePrompt) {
}
}
/**
* Gets an aspect ratio for Stability that is the closest to the given width and height.
* @param {number} width Target width
* @param {number} height Target height
* @returns {string} Closest aspect ratio as a string
*/
function getClosestAspectRatio(width, height) {
const aspectRatios = {
'16:9': 16 / 9,
'1:1': 1,
'21:9': 21 / 9,
'2:3': 2 / 3,
'3:2': 3 / 2,
'4:5': 4 / 5,
'5:4': 5 / 4,
'9:16': 9 / 16,
'9:21': 9 / 21,
};
const aspectRatio = width / height;
let closestAspectRatio = Object.keys(aspectRatios)[0];
let minDiff = Math.abs(aspectRatio - aspectRatios[closestAspectRatio]);
for (const key in aspectRatios) {
const diff = Math.abs(aspectRatio - aspectRatios[key]);
if (diff < minDiff) {
minDiff = diff;
closestAspectRatio = key;
}
}
return closestAspectRatio;
}
/**
* Generates an image using Stability AI.
* @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/
async function generateStabilityImage(prompt, negativePrompt) {
const IMAGE_FORMAT = 'png';
const PROMPT_LIMIT = 10000;
try {
const response = await fetch('/api/sd/stability/generate', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
model: extension_settings.sd.model,
payload: {
prompt: prompt.slice(0, PROMPT_LIMIT),
negative_prompt: negativePrompt.slice(0, PROMPT_LIMIT),
aspect_ratio: getClosestAspectRatio(extension_settings.sd.width, extension_settings.sd.height),
seed: extension_settings.sd.seed >= 0 ? extension_settings.sd.seed : undefined,
style_preset: extension_settings.sd.stability_style_preset,
output_format: IMAGE_FORMAT,
},
}),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`HTTP ${response.status}: ${errorText}`);
}
const blob = await response.blob();
const base64Image = await getBase64Async(blob);
return {
format: IMAGE_FORMAT,
data: base64Image,
};
} catch (error) {
console.error('Error generating image with Stability AI:', error);
throw error;
}
}
/**
* Generates a "horde" image using the provided prompt and configuration settings.
*
@ -3356,7 +3371,7 @@ function isValidState() {
case sources.pollinations:
return true;
case sources.stability:
return !!extension_settings.sd.stability_key;
return secret_state[SECRET_KEYS.STABILITY];
}
}
@ -3584,12 +3599,9 @@ jQuery(async () => {
$('#sd_command_visible').on('input', onCommandVisibleInput);
$('#sd_interactive_visible').on('input', onInteractiveVisibleInput);
$('#sd_swap_dimensions').on('click', onSwapDimensionsClick);
$('#sd_stability_key').on('input', onStabilityKeyInput);
$('#sd_stability_engine').on('change', onStabilityEngineChange);
$('#sd_stability_key').on('click', onStabilityKeyClick);
$('#sd_stability_style_preset').on('change', onStabilityStylePresetChange);
$('#sd_stability_aspect_ratio').on('change', onStabilityAspectRatioChange);
$('#sd_stability_output_format').on('change', onStabilityOutputFormatChange);
$('.sd_settings .inline-drawer-toggle').on('click', function () {
initScrollHeight($('#sd_prompt_prefix'));
initScrollHeight($('#sd_negative_prompt'));

View File

@ -44,10 +44,10 @@
<option value="openai">OpenAI (DALL-E)</option>
<option value="pollinations">Pollinations</option>
<option value="vlad">SD.Next (vladmandic)</option>
<option value="stability">Stability AI</option>
<option value="auto">Stable Diffusion Web UI (AUTOMATIC1111)</option>
<option value="horde">Stable Horde</option>
<option value="togetherai">TogetherAI</option>
<option value="stability">Stability AI</option>
</select>
<div data-sd-source="auto">
<label for="sd_auto_url">SD Web UI URL</label>
@ -191,27 +191,22 @@
</div>
</div>
<div data-sd-source="stability">
<label for="sd_stability_key">API Key</label>
<div class="flex-container flexnowrap">
<input id="sd_stability_key" type="password" class="text_pole flex1" placeholder="Enter your Stability AI API key" />
<div id="sd_stability_validate" class="menu_button menu_button_icon">
<i class="fa-solid fa-check"></i>
<span data-i18n="Connect">
Connect
</span>
<div class="flex-container flexnowrap alignItemsBaseline marginBot5">
<strong class="flex1" data-i18n="API Key">API Key</strong>
<div id="sd_stability_key" class="menu_button menu_button_icon">
<i class="fa-fw fa-solid fa-key"></i>
<span data-i18n="Click to set">Click to set</span>
</div>
</div>
<i>You can find your API key in the Stability AI dashboard.</i>
<div class="marginBot5">
<i data-i18n="You can find your API key in the Stability AI dashboard.">
You can find your API key in the Stability AI dashboard.
</i>
</div>
<div class="flex-container">
<div class="flex1">
<label for="sd_stability_engine">Engine</label>
<select id="sd_stability_engine">
<option value="v2beta">V2beta Image Generation</option>
</select>
</div>
<div class="flex1">
<label for="sd_stability_style_preset">Style Preset</label>
<label for="sd_stability_style_preset" data-i18n="Style Preset">Style Preset</label>
<select id="sd_stability_style_preset">
<option value="anime">Anime</option>
<option value="3d-model">3D Model</option>
@ -233,39 +228,7 @@
</select>
</div>
</div>
<div class="flex-container">
<div class="flex1">
<label for="sd_stability_aspect_ratio">Aspect Ratio</label>
<select id="sd_stability_aspect_ratio">
<option value="16:9">16:9</option>
<option value="1:1">1:1</option>
<option value="21:9">21:9</option>
<option value="2:3">2:3</option>
<option value="3:2">3:2</option>
<option value="4:5">4:5</option>
<option value="5:4">5:4</option>
<option value="9:16">9:16</option>
<option value="9:21">9:21</option>
</select>
</div>
<div class="flex1">
<label for="sd_stability_seed">Seed</label>
<input id="sd_stability_seed" type="number" class="text_pole" value="0" min="0" max="4294967295" />
</div>
</div>
<div class="flex-container">
<div class="flex1">
<label for="sd_stability_output_format">Output Format</label>
<select id="sd_stability_output_format">
<option value="png">PNG</option>
<option value="webp">WebP</option>
<option value="jpeg">JPEG</option>
</select>
</div>
</div>
</div>
</div>
<div class="flex-container">
<div class="flex1">
<label for="sd_model" data-i18n="Model">Model</label>
@ -415,7 +378,7 @@
</label>
</div>
<div data-sd-source="novel,togetherai,pollinations,comfy,drawthings,vlad,auto,horde,extras" class="marginTop5">
<div data-sd-source="novel,togetherai,pollinations,comfy,drawthings,vlad,auto,horde,extras,stability" class="marginTop5">
<label for="sd_seed">
<span data-i18n="Seed">Seed</span>
<small data-i18n="(-1 for random)">(-1 for random)</small>

View File

@ -31,6 +31,7 @@ export const SECRET_KEYS = {
FEATHERLESS: 'api_key_featherless',
ZEROONEAI: 'api_key_01ai',
HUGGINGFACE: 'api_key_huggingface',
STABILITY: 'api_key_stability',
};
const INPUT_MAP = {

View File

@ -43,6 +43,7 @@ const SECRET_KEYS = {
FEATHERLESS: 'api_key_featherless',
ZEROONEAI: 'api_key_01ai',
HUGGINGFACE: 'api_key_huggingface',
STABILITY: 'api_key_stability',
};
// These are the keys that are safe to expose, even if allowKeysExposure is false

View File

@ -7,6 +7,7 @@ const path = require('path');
const writeFileAtomicSync = require('write-file-atomic').sync;
const { jsonParser } = require('../express-common');
const { readSecret, SECRET_KEYS } = require('./secrets.js');
const FormData = require('form-data');
/**
* Sanitizes a string.
@ -793,9 +794,68 @@ pollinations.post('/generate', jsonParser, async (request, response) => {
}
});
const stability = express.Router();
stability.post('/generate', jsonParser, async (request, response) => {
try {
const key = readSecret(request.user.directories, SECRET_KEYS.STABILITY);
if (!key) {
console.log('Stability AI key not found.');
return response.sendStatus(400);
}
const { payload, model } = request.body;
const formData = new FormData();
for (const [key, value] of Object.entries(payload)) {
if (value !== undefined) {
formData.append(key, String(value));
}
}
let apiUrl;
switch (model) {
case 'stable-image-ultra':
apiUrl = 'https://api.stability.ai/v2beta/stable-image/generate/ultra';
break;
case 'stable-image-core':
apiUrl = 'https://api.stability.ai/v2beta/stable-image/generate/core';
break;
case 'stable-diffusion-3':
apiUrl = 'https://api.stability.ai/v2beta/stable-image/generate/sd3';
break;
default:
throw new Error('Invalid Stability AI model selected');
}
const result = await fetch(apiUrl, {
method: 'POST',
headers: {
'Authorization': `Bearer ${key}`,
'Accept': 'image/*',
},
body: formData,
});
if (!result.ok) {
const text = await result.text();
console.log('Stability AI returned an error.', result.status, result.statusText, text);
return response.sendStatus(500);
}
const buffer = await result.buffer();
return response.send(buffer);
} catch (error) {
console.log(error);
return response.sendStatus(500);
}
});
router.use('/comfy', comfy);
router.use('/together', together);
router.use('/drawthings', drawthings);
router.use('/pollinations', pollinations);
router.use('/stability', stability);
module.exports = { router };