From 1896732f176c53feb87138ffadf3dd3ee3ab7ea3 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Mon, 6 Nov 2023 21:47:00 +0200 Subject: [PATCH] Add DALL-E to OpenAI plugin --- .../extensions/stable-diffusion/index.js | 115 ++++++++++++++++-- .../extensions/stable-diffusion/settings.html | 16 +++ server.js | 3 + src/openai.js | 46 +++++++ 4 files changed, 173 insertions(+), 7 deletions(-) create mode 100644 src/openai.js diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index 2e5f07c73..8fbc3f491 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -36,6 +36,7 @@ const sources = { auto: 'auto', novel: 'novel', vlad: 'vlad', + openai: 'openai', } const generationMode = { @@ -206,6 +207,10 @@ const defaultSettings = { novel_upscale_ratio: 1.0, novel_anlas_guard: false, + // OpenAI settings + openai_style: 'vivid', + openai_quality: 'standard', + style: 'Default', styles: defaultStyles, } @@ -341,6 +346,8 @@ async function loadSettings() { $('#sd_vlad_url').val(extension_settings.sd.vlad_url); $('#sd_vlad_auth').val(extension_settings.sd.vlad_auth); $('#sd_interactive_mode').prop('checked', extension_settings.sd.interactive_mode); + $('#sd_openai_style').val(extension_settings.sd.openai_style); + $('#sd_openai_quality').val(extension_settings.sd.openai_quality); for (const style of extension_settings.sd.styles) { const option = document.createElement('option'); @@ -601,6 +608,16 @@ async function onSourceChange() { await Promise.all([loadModels(), loadSamplers()]); } +async function onOpenAiStyleSelect() { + extension_settings.sd.openai_style = String($('#sd_openai_style').find(':selected').val()); + saveSettingsDebounced(); +} + +async function onOpenAiQualitySelect() { + extension_settings.sd.openai_quality = String($('#sd_openai_quality').find(':selected').val()); + saveSettingsDebounced(); +} + async function onViewAnlasClick() { const result = await loadNovelSubscriptionData(); @@ -746,7 +763,7 @@ async function onModelChange() { extension_settings.sd.model = $('#sd_model').find(':selected').val(); saveSettingsDebounced(); - const cloudSources = [sources.horde, sources.novel]; + const cloudSources = [sources.horde, sources.novel, sources.openai]; if (cloudSources.includes(extension_settings.sd.source)) { return; @@ -874,6 +891,9 @@ async function loadSamplers() { case sources.vlad: samplers = await loadVladSamplers(); break; + case sources.openai: + samplers = await loadOpenAiSamplers(); + break; } for (const sampler of samplers) { @@ -939,6 +959,10 @@ async function loadAutoSamplers() { } } +async function loadOpenAiSamplers() { + return ['N/A']; +} + async function loadVladSamplers() { if (!extension_settings.sd.vlad_url) { return []; @@ -999,6 +1023,9 @@ async function loadModels() { case sources.vlad: models = await loadVladModels(); break; + case sources.openai: + models = await loadOpenAiModels(); + break; } for (const model of models) { @@ -1096,6 +1123,13 @@ async function loadAutoModels() { } } +async function loadOpenAiModels() { + return [ + { value: 'dall-e-2', text: 'DALL-E 2' }, + { value: 'dall-e-3', text: 'DALL-E 3' }, + ]; +} + async function loadVladModels() { if (!extension_settings.sd.vlad_url) { return []; @@ -1368,13 +1402,17 @@ async function sendGenerationRequest(generationType, prompt, characterName = nul case sources.novel: result = await generateNovelImage(prefixedPrompt); break; + case sources.openai: + result = await generateOpenAiImage(prefixedPrompt); + break; } if (!result.data) { - throw new Error(); + throw new Error('Endpoint did not return image data.'); } } catch (err) { - toastr.error('Image generation failed. Please try again', 'Stable Diffusion'); + console.error(err); + toastr.error('Image generation failed. Please try again.' + '\n\n' + String(err), 'Stable Diffusion'); return; } @@ -1425,7 +1463,8 @@ async function generateExtrasImage(prompt) { const data = await result.json(); return { format: 'jpg', data: data.image }; } else { - throw new Error(); + const text = await result.text(); + throw new Error(text); } } @@ -1459,7 +1498,8 @@ async function generateHordeImage(prompt) { const data = await result.text(); return { format: 'webp', data: data }; } else { - throw new Error(); + const text = await result.text(); + throw new Error(text); } } @@ -1500,7 +1540,8 @@ async function generateAutoImage(prompt) { const data = await result.json(); return { format: 'png', data: data.images[0] }; } else { - throw new Error(); + const text = await result.text(); + throw new Error(text); } } @@ -1533,7 +1574,8 @@ async function generateNovelImage(prompt) { const data = await result.text(); return { format: 'png', data: data }; } else { - throw new Error(); + const text = await result.text(); + throw new Error(text); } } @@ -1592,6 +1634,61 @@ function getNovelParams() { return { steps, width, height }; } +async function generateOpenAiImage(prompt) { + const dalle2PromptLimit = 1000; + const dalle3PromptLimit = 4000; + + const isDalle2 = extension_settings.sd.model === 'dall-e-2'; + const isDalle3 = extension_settings.sd.model === 'dall-e-3'; + + if (isDalle2 && prompt.length > dalle2PromptLimit) { + prompt = prompt.substring(0, dalle2PromptLimit); + } + + if (isDalle3 && prompt.length > dalle3PromptLimit) { + prompt = prompt.substring(0, dalle3PromptLimit); + } + + let width = 1024; + let height = 1024; + let aspectRatio = extension_settings.sd.width / extension_settings.sd.height; + + if (isDalle3 && aspectRatio < 1) { + height = 1792; + } + + if (isDalle3 && aspectRatio > 1) { + width = 1792; + } + + if (isDalle2 && (extension_settings.sd.width <= 512 && extension_settings.sd.height <= 512)) { + width = 512; + height = 512; + } + + const result = await fetch('/api/openai/generate-image', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ + prompt: prompt, + model: extension_settings.sd.model, + size: `${width}x${height}`, + n: 1, + quality: isDalle3 ? extension_settings.sd.openai_quality : undefined, + style: isDalle3 ? extension_settings.sd.openai_style : undefined, + response_format: 'b64_json', + }), + }); + + if (result.ok) { + const data = await result.json(); + return { format: 'png', data: data?.data[0]?.b64_json }; + } else { + const text = await result.text(); + throw new Error(text); + } +} + async function sendMessage(prompt, image, generationType) { const context = getContext(); const messageText = `[${context.name2} sends a picture that contains: ${prompt}]`; @@ -1683,6 +1780,8 @@ function isValidState() { return !!extension_settings.sd.vlad_url; case sources.novel: return secret_state[SECRET_KEYS.NOVEL]; + case sources.openai: + return secret_state[SECRET_KEYS.OPENAI]; } } @@ -1826,6 +1925,8 @@ jQuery(async () => { $('#sd_save_style').on('click', onSaveStyleClick); $('#sd_character_prompt_block').hide(); $('#sd_interactive_mode').on('input', onInteractiveModeInput); + $('#sd_openai_style').on('change', onOpenAiStyleSelect); + $('#sd_openai_quality').on('change', onOpenAiQualitySelect); $('.sd_settings .inline-drawer-toggle').on('click', function () { initScrollHeight($("#sd_prompt_prefix")); diff --git a/public/scripts/extensions/stable-diffusion/settings.html b/public/scripts/extensions/stable-diffusion/settings.html index 5a55fa611..eee6016cd 100644 --- a/public/scripts/extensions/stable-diffusion/settings.html +++ b/public/scripts/extensions/stable-diffusion/settings.html @@ -29,6 +29,7 @@ +
@@ -96,6 +97,21 @@
Hint: Save an API key in the NovelAI API settings to use it here. +
+ These settings only apply to DALL-E 3 +
+ + + + +
+
diff --git a/server.js b/server.js index f2a84df23..704c5720c 100644 --- a/server.js +++ b/server.js @@ -3429,6 +3429,9 @@ async function postAsync(url, args) { return fetchJSON(url, { method: 'POST', ti // ** END ** +// OpenAI API +require('./src/openai').registerEndpoints(app, jsonParser); + // Tokenizers require('./src/tokenizers').registerEndpoints(app, jsonParser); diff --git a/src/openai.js b/src/openai.js new file mode 100644 index 000000000..0bffee92a --- /dev/null +++ b/src/openai.js @@ -0,0 +1,46 @@ +const { readSecret, SECRET_KEYS } = require("./secrets"); + +/** + * Registers the OpenAI endpoints. + * @param {import("express").Express} app + * @param {any} jsonParser + */ +function registerEndpoints(app, jsonParser) { + app.post('/api/openai/generate-image', jsonParser, async (request, response) => { + try { + const key = readSecret(SECRET_KEYS.OPENAI); + + if (!key) { + console.log('No OpenAI key found'); + return response.sendStatus(401); + } + + console.log('OpenAI request', request.body); + + const result = await fetch('https://api.openai.com/v1/images/generations', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${key}`, + }, + body: JSON.stringify(request.body), + }); + + if (!result.ok) { + const text = await result.text(); + console.log('OpenAI request failed', result.statusText, text); + return response.status(500).send(text); + } + + const data = await result.json(); + return response.json(data); + } catch (error) { + console.error(error); + response.status(500).send('Internal server error'); + } + }); +} + +module.exports = { + registerEndpoints, +};