Add DALL-E to OpenAI plugin

This commit is contained in:
Cohee
2023-11-06 21:47:00 +02:00
parent 57e845d0d7
commit 1896732f17
4 changed files with 173 additions and 7 deletions

View File

@ -36,6 +36,7 @@ const sources = {
auto: 'auto',
novel: 'novel',
vlad: 'vlad',
openai: 'openai',
}
const generationMode = {
@ -206,6 +207,10 @@ const defaultSettings = {
novel_upscale_ratio: 1.0,
novel_anlas_guard: false,
// OpenAI settings
openai_style: 'vivid',
openai_quality: 'standard',
style: 'Default',
styles: defaultStyles,
}
@ -341,6 +346,8 @@ async function loadSettings() {
$('#sd_vlad_url').val(extension_settings.sd.vlad_url);
$('#sd_vlad_auth').val(extension_settings.sd.vlad_auth);
$('#sd_interactive_mode').prop('checked', extension_settings.sd.interactive_mode);
$('#sd_openai_style').val(extension_settings.sd.openai_style);
$('#sd_openai_quality').val(extension_settings.sd.openai_quality);
for (const style of extension_settings.sd.styles) {
const option = document.createElement('option');
@ -601,6 +608,16 @@ async function onSourceChange() {
await Promise.all([loadModels(), loadSamplers()]);
}
async function onOpenAiStyleSelect() {
extension_settings.sd.openai_style = String($('#sd_openai_style').find(':selected').val());
saveSettingsDebounced();
}
async function onOpenAiQualitySelect() {
extension_settings.sd.openai_quality = String($('#sd_openai_quality').find(':selected').val());
saveSettingsDebounced();
}
async function onViewAnlasClick() {
const result = await loadNovelSubscriptionData();
@ -746,7 +763,7 @@ async function onModelChange() {
extension_settings.sd.model = $('#sd_model').find(':selected').val();
saveSettingsDebounced();
const cloudSources = [sources.horde, sources.novel];
const cloudSources = [sources.horde, sources.novel, sources.openai];
if (cloudSources.includes(extension_settings.sd.source)) {
return;
@ -874,6 +891,9 @@ async function loadSamplers() {
case sources.vlad:
samplers = await loadVladSamplers();
break;
case sources.openai:
samplers = await loadOpenAiSamplers();
break;
}
for (const sampler of samplers) {
@ -939,6 +959,10 @@ async function loadAutoSamplers() {
}
}
async function loadOpenAiSamplers() {
return ['N/A'];
}
async function loadVladSamplers() {
if (!extension_settings.sd.vlad_url) {
return [];
@ -999,6 +1023,9 @@ async function loadModels() {
case sources.vlad:
models = await loadVladModels();
break;
case sources.openai:
models = await loadOpenAiModels();
break;
}
for (const model of models) {
@ -1096,6 +1123,13 @@ async function loadAutoModels() {
}
}
async function loadOpenAiModels() {
return [
{ value: 'dall-e-2', text: 'DALL-E 2' },
{ value: 'dall-e-3', text: 'DALL-E 3' },
];
}
async function loadVladModels() {
if (!extension_settings.sd.vlad_url) {
return [];
@ -1368,13 +1402,17 @@ async function sendGenerationRequest(generationType, prompt, characterName = nul
case sources.novel:
result = await generateNovelImage(prefixedPrompt);
break;
case sources.openai:
result = await generateOpenAiImage(prefixedPrompt);
break;
}
if (!result.data) {
throw new Error();
throw new Error('Endpoint did not return image data.');
}
} catch (err) {
toastr.error('Image generation failed. Please try again', 'Stable Diffusion');
console.error(err);
toastr.error('Image generation failed. Please try again.' + '\n\n' + String(err), 'Stable Diffusion');
return;
}
@ -1425,7 +1463,8 @@ async function generateExtrasImage(prompt) {
const data = await result.json();
return { format: 'jpg', data: data.image };
} else {
throw new Error();
const text = await result.text();
throw new Error(text);
}
}
@ -1459,7 +1498,8 @@ async function generateHordeImage(prompt) {
const data = await result.text();
return { format: 'webp', data: data };
} else {
throw new Error();
const text = await result.text();
throw new Error(text);
}
}
@ -1500,7 +1540,8 @@ async function generateAutoImage(prompt) {
const data = await result.json();
return { format: 'png', data: data.images[0] };
} else {
throw new Error();
const text = await result.text();
throw new Error(text);
}
}
@ -1533,7 +1574,8 @@ async function generateNovelImage(prompt) {
const data = await result.text();
return { format: 'png', data: data };
} else {
throw new Error();
const text = await result.text();
throw new Error(text);
}
}
@ -1592,6 +1634,61 @@ function getNovelParams() {
return { steps, width, height };
}
async function generateOpenAiImage(prompt) {
const dalle2PromptLimit = 1000;
const dalle3PromptLimit = 4000;
const isDalle2 = extension_settings.sd.model === 'dall-e-2';
const isDalle3 = extension_settings.sd.model === 'dall-e-3';
if (isDalle2 && prompt.length > dalle2PromptLimit) {
prompt = prompt.substring(0, dalle2PromptLimit);
}
if (isDalle3 && prompt.length > dalle3PromptLimit) {
prompt = prompt.substring(0, dalle3PromptLimit);
}
let width = 1024;
let height = 1024;
let aspectRatio = extension_settings.sd.width / extension_settings.sd.height;
if (isDalle3 && aspectRatio < 1) {
height = 1792;
}
if (isDalle3 && aspectRatio > 1) {
width = 1792;
}
if (isDalle2 && (extension_settings.sd.width <= 512 && extension_settings.sd.height <= 512)) {
width = 512;
height = 512;
}
const result = await fetch('/api/openai/generate-image', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
prompt: prompt,
model: extension_settings.sd.model,
size: `${width}x${height}`,
n: 1,
quality: isDalle3 ? extension_settings.sd.openai_quality : undefined,
style: isDalle3 ? extension_settings.sd.openai_style : undefined,
response_format: 'b64_json',
}),
});
if (result.ok) {
const data = await result.json();
return { format: 'png', data: data?.data[0]?.b64_json };
} else {
const text = await result.text();
throw new Error(text);
}
}
async function sendMessage(prompt, image, generationType) {
const context = getContext();
const messageText = `[${context.name2} sends a picture that contains: ${prompt}]`;
@ -1683,6 +1780,8 @@ function isValidState() {
return !!extension_settings.sd.vlad_url;
case sources.novel:
return secret_state[SECRET_KEYS.NOVEL];
case sources.openai:
return secret_state[SECRET_KEYS.OPENAI];
}
}
@ -1826,6 +1925,8 @@ jQuery(async () => {
$('#sd_save_style').on('click', onSaveStyleClick);
$('#sd_character_prompt_block').hide();
$('#sd_interactive_mode').on('input', onInteractiveModeInput);
$('#sd_openai_style').on('change', onOpenAiStyleSelect);
$('#sd_openai_quality').on('change', onOpenAiQualitySelect);
$('.sd_settings .inline-drawer-toggle').on('click', function () {
initScrollHeight($("#sd_prompt_prefix"));