Add DALL-E to OpenAI plugin

This commit is contained in:
Cohee 2023-11-06 21:47:00 +02:00
parent 57e845d0d7
commit 1896732f17
4 changed files with 173 additions and 7 deletions

View File

@ -36,6 +36,7 @@ const sources = {
auto: 'auto',
novel: 'novel',
vlad: 'vlad',
openai: 'openai',
}
const generationMode = {
@ -206,6 +207,10 @@ const defaultSettings = {
novel_upscale_ratio: 1.0,
novel_anlas_guard: false,
// OpenAI settings
openai_style: 'vivid',
openai_quality: 'standard',
style: 'Default',
styles: defaultStyles,
}
@ -341,6 +346,8 @@ async function loadSettings() {
$('#sd_vlad_url').val(extension_settings.sd.vlad_url);
$('#sd_vlad_auth').val(extension_settings.sd.vlad_auth);
$('#sd_interactive_mode').prop('checked', extension_settings.sd.interactive_mode);
$('#sd_openai_style').val(extension_settings.sd.openai_style);
$('#sd_openai_quality').val(extension_settings.sd.openai_quality);
for (const style of extension_settings.sd.styles) {
const option = document.createElement('option');
@ -601,6 +608,16 @@ async function onSourceChange() {
await Promise.all([loadModels(), loadSamplers()]);
}
/**
 * Change handler for the `#sd_openai_style` dropdown.
 *
 * Reads the value of the currently selected option, stores it (coerced to a
 * string) in `extension_settings.sd.openai_style`, then calls
 * `saveSettingsDebounced()`.
 * @returns {Promise<void>}
 */
async function onOpenAiStyleSelect() {
    const selectedOption = $('#sd_openai_style').find(':selected');
    const styleValue = String(selectedOption.val());

    extension_settings.sd.openai_style = styleValue;
    saveSettingsDebounced();
}
/**
 * Change handler for the `#sd_openai_quality` dropdown.
 *
 * Reads the value of the currently selected option, stores it (coerced to a
 * string) in `extension_settings.sd.openai_quality`, then calls
 * `saveSettingsDebounced()`.
 * @returns {Promise<void>}
 */
async function onOpenAiQualitySelect() {
    const selectedOption = $('#sd_openai_quality').find(':selected');
    const qualityValue = String(selectedOption.val());

    extension_settings.sd.openai_quality = qualityValue;
    saveSettingsDebounced();
}
async function onViewAnlasClick() {
const result = await loadNovelSubscriptionData();
@ -746,7 +763,7 @@ async function onModelChange() {
extension_settings.sd.model = $('#sd_model').find(':selected').val();
saveSettingsDebounced();
const cloudSources = [sources.horde, sources.novel];
const cloudSources = [sources.horde, sources.novel, sources.openai];
if (cloudSources.includes(extension_settings.sd.source)) {
return;
@ -874,6 +891,9 @@ async function loadSamplers() {
case sources.vlad:
samplers = await loadVladSamplers();
break;
case sources.openai:
samplers = await loadOpenAiSamplers();
break;
}
for (const sampler of samplers) {
@ -939,6 +959,10 @@ async function loadAutoSamplers() {
}
}
/**
 * Provides the sampler list used when the OpenAI source is selected.
 *
 * The list always contains the single placeholder entry 'N/A'.
 * @returns {Promise<string[]>} A one-element array containing 'N/A'.
 */
async function loadOpenAiSamplers() {
    const placeholderSampler = 'N/A';
    return [placeholderSampler];
}
async function loadVladSamplers() {
if (!extension_settings.sd.vlad_url) {
return [];
@ -999,6 +1023,9 @@ async function loadModels() {
case sources.vlad:
models = await loadVladModels();
break;
case sources.openai:
models = await loadOpenAiModels();
break;
}
for (const model of models) {
@ -1096,6 +1123,13 @@ async function loadAutoModels() {
}
}
/**
 * Provides the fixed list of models offered for the OpenAI source.
 *
 * @returns {Promise<{value: string, text: string}[]>} One entry per model, where
 * `value` is the model identifier and `text` is its display label.
 */
async function loadOpenAiModels() {
    // [model identifier, display label] pairs, in display order.
    const modelLabels = [
        ['dall-e-2', 'DALL-E 2'],
        ['dall-e-3', 'DALL-E 3'],
    ];

    return modelLabels.map(([value, text]) => ({ value, text }));
}
async function loadVladModels() {
if (!extension_settings.sd.vlad_url) {
return [];
@ -1368,13 +1402,17 @@ async function sendGenerationRequest(generationType, prompt, characterName = nul
case sources.novel:
result = await generateNovelImage(prefixedPrompt);
break;
case sources.openai:
result = await generateOpenAiImage(prefixedPrompt);
break;
}
if (!result.data) {
throw new Error();
throw new Error('Endpoint did not return image data.');
}
} catch (err) {
toastr.error('Image generation failed. Please try again', 'Stable Diffusion');
console.error(err);
toastr.error('Image generation failed. Please try again.' + '\n\n' + String(err), 'Stable Diffusion');
return;
}
@ -1425,7 +1463,8 @@ async function generateExtrasImage(prompt) {
const data = await result.json();
return { format: 'jpg', data: data.image };
} else {
throw new Error();
const text = await result.text();
throw new Error(text);
}
}
@ -1459,7 +1498,8 @@ async function generateHordeImage(prompt) {
const data = await result.text();
return { format: 'webp', data: data };
} else {
throw new Error();
const text = await result.text();
throw new Error(text);
}
}
@ -1500,7 +1540,8 @@ async function generateAutoImage(prompt) {
const data = await result.json();
return { format: 'png', data: data.images[0] };
} else {
throw new Error();
const text = await result.text();
throw new Error(text);
}
}
@ -1533,7 +1574,8 @@ async function generateNovelImage(prompt) {
const data = await result.text();
return { format: 'png', data: data };
} else {
throw new Error();
const text = await result.text();
throw new Error(text);
}
}
@ -1592,6 +1634,61 @@ function getNovelParams() {
return { steps, width, height };
}
/**
 * Requests an image from the server's `/api/openai/generate-image` endpoint.
 *
 * The prompt is cut to 1000 characters for 'dall-e-2' and 4000 characters for
 * 'dall-e-3'. The requested size is derived from `extension_settings.sd`:
 * - 'dall-e-3': 1024x1792 when width/height < 1, 1792x1024 when > 1, else 1024x1024.
 * - 'dall-e-2': 512x512 when both configured dimensions are <= 512, else 1024x1024.
 * - any other model: 1024x1024.
 * `quality` and `style` are only sent for 'dall-e-3'.
 *
 * @param {string} prompt Text prompt to send.
 * @returns {Promise<{format: string, data: string|undefined}>} The image format
 * ('png') and the `b64_json` payload of the first returned image. `data` is
 * `undefined` when the response carries no image entry.
 * @throws {Error} With the response body text as message when the endpoint
 * replies with a non-OK status.
 */
async function generateOpenAiImage(prompt) {
    const dalle2PromptLimit = 1000;
    const dalle3PromptLimit = 4000;

    const settings = extension_settings.sd;
    const isDalle2 = settings.model === 'dall-e-2';
    const isDalle3 = settings.model === 'dall-e-3';

    // Truncate into a new variable instead of reassigning the parameter.
    let requestPrompt = prompt;

    if (isDalle2 && prompt.length > dalle2PromptLimit) {
        requestPrompt = prompt.substring(0, dalle2PromptLimit);
    } else if (isDalle3 && prompt.length > dalle3PromptLimit) {
        requestPrompt = prompt.substring(0, dalle3PromptLimit);
    }

    let width = 1024;
    let height = 1024;

    if (isDalle3) {
        // A NaN ratio (missing width/height) fails both comparisons and keeps the square default.
        const aspectRatio = settings.width / settings.height;

        if (aspectRatio < 1) {
            height = 1792;
        } else if (aspectRatio > 1) {
            width = 1792;
        }
    } else if (isDalle2 && settings.width <= 512 && settings.height <= 512) {
        width = 512;
        height = 512;
    }

    const result = await fetch('/api/openai/generate-image', {
        method: 'POST',
        headers: getRequestHeaders(),
        body: JSON.stringify({
            prompt: requestPrompt,
            model: settings.model,
            size: `${width}x${height}`,
            n: 1,
            // `undefined` members are dropped by JSON.stringify, so these are only sent for DALL-E 3.
            quality: isDalle3 ? settings.openai_quality : undefined,
            style: isDalle3 ? settings.openai_style : undefined,
            response_format: 'b64_json',
        }),
    });

    if (!result.ok) {
        const text = await result.text();
        throw new Error(text);
    }

    const data = await result.json();

    // Every step is optional: a body without a `data` array must produce
    // `data: undefined` rather than a TypeError, so the caller can report
    // that no image data was returned.
    return { format: 'png', data: data?.data?.[0]?.b64_json };
}
async function sendMessage(prompt, image, generationType) {
const context = getContext();
const messageText = `[${context.name2} sends a picture that contains: ${prompt}]`;
@ -1683,6 +1780,8 @@ function isValidState() {
return !!extension_settings.sd.vlad_url;
case sources.novel:
return secret_state[SECRET_KEYS.NOVEL];
case sources.openai:
return secret_state[SECRET_KEYS.OPENAI];
}
}
@ -1826,6 +1925,8 @@ jQuery(async () => {
$('#sd_save_style').on('click', onSaveStyleClick);
$('#sd_character_prompt_block').hide();
$('#sd_interactive_mode').on('input', onInteractiveModeInput);
$('#sd_openai_style').on('change', onOpenAiStyleSelect);
$('#sd_openai_quality').on('change', onOpenAiQualitySelect);
$('.sd_settings .inline-drawer-toggle').on('click', function () {
initScrollHeight($("#sd_prompt_prefix"));

View File

@ -29,6 +29,7 @@
<option value="auto">Stable Diffusion Web UI (AUTOMATIC1111)</option>
<option value="vlad">SD.Next (vladmandic)</option>
<option value="novel">NovelAI Diffusion</option>
<option value="openai">OpenAI (DALL-E)</option>
</select>
<div data-sd-source="auto">
<label for="sd_auto_url">SD Web UI URL</label>
@ -96,6 +97,21 @@
</div>
<i>Hint: Save an API key in the NovelAI API settings to use it here.</i>
</div>
<div data-sd-source="openai">
<small>These settings only apply to DALL-E 3</small>
<div class="flex-container">
<label for="sd_openai_style">Image Style</label>
<select id="sd_openai_style">
<option value="vivid">Vivid</option>
<option value="natural">Natural</option>
</select>
<label for="sd_openai_quality">Image Quality</label>
<select id="sd_openai_quality">
<option value="standard">Standard</option>
<option value="hd">HD</option>
</select>
</div>
</div>
<label for="sd_scale">CFG Scale (<span id="sd_scale_value"></span>)</label>
<input id="sd_scale" type="range" min="{{scale_min}}" max="{{scale_max}}" step="{{scale_step}}" value="{{scale}}" />
<label for="sd_steps">Sampling steps (<span id="sd_steps_value"></span>)</label>

View File

@ -3429,6 +3429,9 @@ async function postAsync(url, args) { return fetchJSON(url, { method: 'POST', ti
// ** END **
// OpenAI API
require('./src/openai').registerEndpoints(app, jsonParser);
// Tokenizers
require('./src/tokenizers').registerEndpoints(app, jsonParser);

46
src/openai.js Normal file
View File

@ -0,0 +1,46 @@
const { readSecret, SECRET_KEYS } = require("./secrets");
/**
 * Registers the OpenAI endpoints on the given Express app.
 *
 * Adds `POST /api/openai/generate-image`, which forwards the request body to
 * the OpenAI image generation API using the stored OpenAI secret:
 * - replies 401 when no key is stored;
 * - replies with the upstream JSON on success;
 * - replies 500 with the upstream body text when the upstream call is not OK;
 * - replies 500 'Internal server error' when anything throws.
 * @param {import("express").Express} app Express application to attach the route to.
 * @param {any} jsonParser Body-parsing middleware placed in front of the handler.
 */
function registerEndpoints(app, jsonParser) {
    const generationsUrl = 'https://api.openai.com/v1/images/generations';

    /** Route handler: relays an image generation request to OpenAI. */
    async function handleGenerateImage(request, response) {
        try {
            const apiKey = readSecret(SECRET_KEYS.OPENAI);

            if (!apiKey) {
                console.log('No OpenAI key found');
                return response.sendStatus(401);
            }

            console.log('OpenAI request', request.body);

            const upstream = await fetch(generationsUrl, {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                    Authorization: `Bearer ${apiKey}`,
                },
                body: JSON.stringify(request.body),
            });

            if (upstream.ok) {
                const payload = await upstream.json();
                return response.json(payload);
            }

            // Upstream failures are relayed with their body text under a 500 status.
            const errorText = await upstream.text();
            console.log('OpenAI request failed', upstream.statusText, errorText);
            return response.status(500).send(errorText);
        } catch (error) {
            console.error(error);
            response.status(500).send('Internal server error');
        }
    }

    app.post('/api/openai/generate-image', jsonParser, handleGenerateImage);
}
// Public interface of this module: callers require it and invoke registerEndpoints(app, jsonParser).
module.exports = {
registerEndpoints,
};