From 1896732f176c53feb87138ffadf3dd3ee3ab7ea3 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Mon, 6 Nov 2023 21:47:00 +0200 Subject: [PATCH] Add DALL-E to OpenAI plugin --- .../extensions/stable-diffusion/index.js | 115 ++++++++++++++++-- .../extensions/stable-diffusion/settings.html | 16 +++ server.js | 3 + src/openai.js | 46 +++++++ 4 files changed, 173 insertions(+), 7 deletions(-) create mode 100644 src/openai.js diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index 2e5f07c73..8fbc3f491 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -36,6 +36,7 @@ const sources = { auto: 'auto', novel: 'novel', vlad: 'vlad', + openai: 'openai', } const generationMode = { @@ -206,6 +207,10 @@ const defaultSettings = { novel_upscale_ratio: 1.0, novel_anlas_guard: false, + // OpenAI settings + openai_style: 'vivid', + openai_quality: 'standard', + style: 'Default', styles: defaultStyles, } @@ -341,6 +346,8 @@ async function loadSettings() { $('#sd_vlad_url').val(extension_settings.sd.vlad_url); $('#sd_vlad_auth').val(extension_settings.sd.vlad_auth); $('#sd_interactive_mode').prop('checked', extension_settings.sd.interactive_mode); + $('#sd_openai_style').val(extension_settings.sd.openai_style); + $('#sd_openai_quality').val(extension_settings.sd.openai_quality); for (const style of extension_settings.sd.styles) { const option = document.createElement('option'); @@ -601,6 +608,16 @@ async function onSourceChange() { await Promise.all([loadModels(), loadSamplers()]); } +async function onOpenAiStyleSelect() { + extension_settings.sd.openai_style = String($('#sd_openai_style').find(':selected').val()); + saveSettingsDebounced(); +} + +async function onOpenAiQualitySelect() { + extension_settings.sd.openai_quality = String($('#sd_openai_quality').find(':selected').val()); + saveSettingsDebounced(); +} + async function onViewAnlasClick() { const result = await loadNovelSubscriptionData(); @@ -746,7 +763,7 @@ async function onModelChange() { extension_settings.sd.model = $('#sd_model').find(':selected').val(); saveSettingsDebounced(); - const cloudSources = [sources.horde, sources.novel]; + const cloudSources = [sources.horde, sources.novel, sources.openai]; if (cloudSources.includes(extension_settings.sd.source)) { return; @@ -874,6 +891,9 @@ async function loadSamplers() { case sources.vlad: samplers = await loadVladSamplers(); break; + case sources.openai: + samplers = await loadOpenAiSamplers(); + break; } for (const sampler of samplers) { @@ -939,6 +959,10 @@ async function loadAutoSamplers() { } } +async function loadOpenAiSamplers() { + return ['N/A']; +} + async function loadVladSamplers() { if (!extension_settings.sd.vlad_url) { return []; @@ -999,6 +1023,9 @@ async function loadModels() { case sources.vlad: models = await loadVladModels(); break; + case sources.openai: + models = await loadOpenAiModels(); + break; } for (const model of models) { @@ -1096,6 +1123,13 @@ async function loadAutoModels() { } } +async function loadOpenAiModels() { + return [ + { value: 'dall-e-2', text: 'DALL-E 2' }, + { value: 'dall-e-3', text: 'DALL-E 3' }, + ]; +} + async function loadVladModels() { if (!extension_settings.sd.vlad_url) { return []; @@ -1368,13 +1402,17 @@ async function sendGenerationRequest(generationType, prompt, characterName = nul case sources.novel: result = await generateNovelImage(prefixedPrompt); break; + case sources.openai: + result = await generateOpenAiImage(prefixedPrompt); + break; } if (!result.data) { - throw new Error(); + throw new Error('Endpoint did not return image data.'); } } catch (err) { - toastr.error('Image generation failed. Please try again', 'Stable Diffusion'); + console.error(err); + toastr.error('Image generation failed. Please try again.' + '\n\n' + String(err), 'Stable Diffusion'); return; } @@ -1425,7 +1463,8 @@ async function generateExtrasImage(prompt) { const data = await result.json(); return { format: 'jpg', data: data.image }; } else { - throw new Error(); + const text = await result.text(); + throw new Error(text); } } @@ -1459,7 +1498,8 @@ async function generateHordeImage(prompt) { const data = await result.text(); return { format: 'webp', data: data }; } else { - throw new Error(); + const text = await result.text(); + throw new Error(text); } } @@ -1500,7 +1540,8 @@ async function generateAutoImage(prompt) { const data = await result.json(); return { format: 'png', data: data.images[0] }; } else { - throw new Error(); + const text = await result.text(); + throw new Error(text); } } @@ -1533,7 +1574,8 @@ async function generateNovelImage(prompt) { const data = await result.text(); return { format: 'png', data: data }; } else { - throw new Error(); + const text = await result.text(); + throw new Error(text); } } @@ -1592,6 +1634,61 @@ function getNovelParams() { return { steps, width, height }; } +async function generateOpenAiImage(prompt) { + const dalle2PromptLimit = 1000; + const dalle3PromptLimit = 4000; + + const isDalle2 = extension_settings.sd.model === 'dall-e-2'; + const isDalle3 = extension_settings.sd.model === 'dall-e-3'; + + if (isDalle2 && prompt.length > dalle2PromptLimit) { + prompt = prompt.substring(0, dalle2PromptLimit); + } + + if (isDalle3 && prompt.length > dalle3PromptLimit) { + prompt = prompt.substring(0, dalle3PromptLimit); + } + + let width = 1024; + let height = 1024; + let aspectRatio = extension_settings.sd.width / extension_settings.sd.height; + + if (isDalle3 && aspectRatio < 1) { + height = 1792; + } + + if (isDalle3 && aspectRatio > 1) { + width = 1792; + } + + if (isDalle2 && (extension_settings.sd.width <= 512 && extension_settings.sd.height <= 512)) { + width = 512; + height = 512; + } + + const result = await fetch('/api/openai/generate-image', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ + prompt: prompt, + model: extension_settings.sd.model, + size: `${width}x${height}`, + n: 1, + quality: isDalle3 ? extension_settings.sd.openai_quality : undefined, + style: isDalle3 ? extension_settings.sd.openai_style : undefined, + response_format: 'b64_json', + }), + }); + + if (result.ok) { + const data = await result.json(); + return { format: 'png', data: data?.data[0]?.b64_json }; + } else { + const text = await result.text(); + throw new Error(text); + } +} + async function sendMessage(prompt, image, generationType) { const context = getContext(); const messageText = `[${context.name2} sends a picture that contains: ${prompt}]`; @@ -1683,6 +1780,8 @@ function isValidState() { return !!extension_settings.sd.vlad_url; case sources.novel: return secret_state[SECRET_KEYS.NOVEL]; + case sources.openai: + return secret_state[SECRET_KEYS.OPENAI]; } } @@ -1826,6 +1925,8 @@ jQuery(async () => { $('#sd_save_style').on('click', onSaveStyleClick); $('#sd_character_prompt_block').hide(); $('#sd_interactive_mode').on('input', onInteractiveModeInput); + $('#sd_openai_style').on('change', onOpenAiStyleSelect); + $('#sd_openai_quality').on('change', onOpenAiQualitySelect); $('.sd_settings .inline-drawer-toggle').on('click', function () { initScrollHeight($("#sd_prompt_prefix")); diff --git a/public/scripts/extensions/stable-diffusion/settings.html b/public/scripts/extensions/stable-diffusion/settings.html index 5a55fa611..eee6016cd 100644 --- a/public/scripts/extensions/stable-diffusion/settings.html +++ b/public/scripts/extensions/stable-diffusion/settings.html @@ -29,6 +29,7 @@ +