From 6e0ed8552ff01f1ce445d5d62c2b261fababd0f8 Mon Sep 17 00:00:00 2001 From: Kristan Schlikow Date: Thu, 13 Feb 2025 19:34:34 +0100 Subject: [PATCH 1/3] Add support for FAL.AI as image gen provider --- .../extensions/stable-diffusion/index.js | 65 ++++++++++ .../extensions/stable-diffusion/settings.html | 15 +++ public/scripts/secrets.js | 1 + src/endpoints/secrets.js | 1 + src/endpoints/stable-diffusion.js | 120 ++++++++++++++++++ 5 files changed, 202 insertions(+) diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index 9f9960c7d..6d410a309 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -81,6 +81,7 @@ const sources = { huggingface: 'huggingface', nanogpt: 'nanogpt', bfl: 'bfl', + falai: 'falai', }; const initiators = { @@ -1169,6 +1170,10 @@ async function onBflKeyClick() { return onApiKeyClick('BFL API Key:', SECRET_KEYS.BFL); } +async function onFalaiKeyClick() { + return onApiKeyClick('FALAI API Key:', SECRET_KEYS.FALAI); +} + function onBflUpsamplingInput() { extension_settings.sd.bfl_upsampling = !!$('#sd_bfl_upsampling').prop('checked'); saveSettingsDebounced(); @@ -1707,6 +1712,9 @@ async function loadModels() { case sources.bfl: models = await loadBflModels(); break; + case sources.falai: + models = await loadFalaiModels(); + break; } for (const model of models) { @@ -1744,6 +1752,21 @@ async function loadBflModels() { ]; } +async function loadFalaiModels() { + $('#sd_falai_key').toggleClass('success', !!secret_state[SECRET_KEYS.FALAI]); + + const result = await fetch('/api/sd/falai/models', { + method: 'POST', + headers: getRequestHeaders(), + }); + + if (result.ok) { + return await result.json(); + } + + return []; +} + async function loadPollinationsModels() { const result = await fetch('/api/sd/pollinations/models', { method: 'POST', @@ -2081,6 +2104,9 @@ async function loadSchedulers() { case sources.bfl: schedulers = ['N/A']; break; + case sources.falai: + schedulers = ['N/A']; + break; } for (const scheduler of schedulers) { @@ -2735,6 +2761,9 @@ async function sendGenerationRequest(generationType, prompt, additionalNegativeP case sources.bfl: result = await generateBflImage(prefixedPrompt, signal); break; + case sources.falai: + result = await generateFalaiImage(prefixedPrompt, negativePrompt, signal); + break; } if (!result.data) { @@ -3496,6 +3525,39 @@ async function generateBflImage(prompt, signal) { } } +/** + * Generates an image using the FAL.AI API. + * @param {string} prompt - The main instruction used to guide the image generation. + * @param {string} negativePrompt - The negative prompt used to guide the image generation. + * @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request. + * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. + */ +async function generateFalaiImage(prompt, negativePrompt, signal) { + const result = await fetch('/api/sd/falai/generate', { + method: 'POST', + headers: getRequestHeaders(), + signal: signal, + body: JSON.stringify({ + prompt: prompt, + negative_prompt: negativePrompt, + model: extension_settings.sd.model, + steps: clamp(extension_settings.sd.steps, 1, 50), + guidance: clamp(extension_settings.sd.scale, 1.5, 5), + width: clamp(extension_settings.sd.width, 256, 1440), + height: clamp(extension_settings.sd.height, 256, 1440), + seed: extension_settings.sd.seed >= 0 ? extension_settings.sd.seed : undefined, + }), + }); + + if (result.ok) { + const data = await result.json(); + return { format: 'jpg', data: data.image }; + } else { + const text = await result.text(); + throw new Error(text); + } +} + async function onComfyOpenWorkflowEditorClick() { let workflow = await (await fetch('/api/sd/comfy/workflow', { method: 'POST', @@ -3782,6 +3844,8 @@ function isValidState() { return secret_state[SECRET_KEYS.NANOGPT]; case sources.bfl: return secret_state[SECRET_KEYS.BFL]; + case sources.falai: + return secret_state[SECRET_KEYS.FALAI]; } } @@ -4443,6 +4507,7 @@ jQuery(async () => { $('#sd_function_tool').on('input', onFunctionToolInput); $('#sd_bfl_key').on('click', onBflKeyClick); $('#sd_bfl_upsampling').on('input', onBflUpsamplingInput); + $('#sd_falai_key').on('click', onFalaiKeyClick); if (!CSS.supports('field-sizing', 'content')) { $('.sd_settings .inline-drawer-toggle').on('click', function () { diff --git a/public/scripts/extensions/stable-diffusion/settings.html b/public/scripts/extensions/stable-diffusion/settings.html index 7969ef7e2..7f766d8cd 100644 --- a/public/scripts/extensions/stable-diffusion/settings.html +++ b/public/scripts/extensions/stable-diffusion/settings.html @@ -52,6 +52,7 @@ +
@@ -256,6 +257,20 @@
+
+
+ + API Key + + + + +
+
+
diff --git a/public/scripts/secrets.js b/public/scripts/secrets.js index f2e27f3e2..261c31480 100644 --- a/public/scripts/secrets.js +++ b/public/scripts/secrets.js @@ -41,6 +41,7 @@ export const SECRET_KEYS = { GENERIC: 'api_key_generic', DEEPSEEK: 'api_key_deepseek', SERPER: 'api_key_serper', + FALAI: 'api_key_falai', }; const INPUT_MAP = { diff --git a/src/endpoints/secrets.js b/src/endpoints/secrets.js index 5683da4d8..a7f0094ba 100644 --- a/src/endpoints/secrets.js +++ b/src/endpoints/secrets.js @@ -50,6 +50,7 @@ export const SECRET_KEYS = { TAVILY: 'api_key_tavily', NANOGPT: 'api_key_nanogpt', BFL: 'api_key_bfl', + FALAI: 'api_key_falai', GENERIC: 'api_key_generic', DEEPSEEK: 'api_key_deepseek', SERPER: 'api_key_serper', diff --git a/src/endpoints/stable-diffusion.js b/src/endpoints/stable-diffusion.js index 2d611dc6b..2c23bab98 100644 --- a/src/endpoints/stable-diffusion.js +++ b/src/endpoints/stable-diffusion.js @@ -1228,6 +1228,125 @@ bfl.post('/generate', jsonParser, async (request, response) => { } }); +const falai = express.Router(); + +falai.post('/models', jsonParser, async (_request, response) => { + try { + const modelsUrl = new URL('https://fal.ai/api/models?categories=text-to-image'); + const result = await fetch(modelsUrl); + + if (!result.ok) { + console.warn('FAL.AI returned an error.', result.status, result.statusText); + throw new Error('FAL.AI request failed.'); + } + + const data = await result.json(); + + if (!Array.isArray(data)) { + console.warn('FAL.AI returned invalid data.'); + throw new Error('FAL.AI request failed.'); + } + + const models = data + .filter(x => !x.title.toLowerCase().includes('inpainting') && + !x.title.toLowerCase().includes('control') && + !x.title.toLowerCase().includes('upscale')) + .map(x => ({ value: x.modelUrl.split('fal-ai/')[1], text: x.title })); + return response.send(models); + } catch (error) { + console.error(error); + return response.sendStatus(500); + } +}); + +falai.post('/generate', jsonParser, async (request, response) => { + try { + const key = readSecret(request.user.directories, SECRET_KEYS.FALAI); + + if (!key) { + console.warn('FAL.AI key not found.'); + return response.sendStatus(400); + } + + const requestBody = { + prompt: request.body.prompt, + image_size: { 'width': request.body.width, 'height': request.body.height }, + num_inference_steps: request.body.steps, + seed: request.body.seed ?? null, + guidance_scale: request.body.guidance, + enable_safety_checker: false, + }; + + console.debug('FAL.AI request:', requestBody); + + const result = await fetch(`https://queue.fal.run/fal-ai/${request.body.model}`, { + method: 'POST', + body: JSON.stringify(requestBody), + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Key ${key}`, + }, + }); + + if (!result.ok) { + console.warn('FAL.AI returned an error.'); + return response.sendStatus(500); + } + + /** @type {any} */ + const taskData = await result.json(); + const { status_url } = taskData; + + const MAX_ATTEMPTS = 100; + for (let i = 0; i < MAX_ATTEMPTS; i++) { + await delay(2500); + + const statusResult = await fetch(status_url, { + headers: { + 'Authorization': `Key ${key}`, + }, + }); + + if (!statusResult.ok) { + const text = await statusResult.text(); + console.warn('FAL.AI returned an error.', text); + return response.sendStatus(500); + } + + /** @type {any} */ + const statusData = await statusResult.json(); + + if (statusData?.status === 'IN_QUEUE' || statusData?.status === 'IN_PROGRESS') { + continue; + } + + if (statusData?.status === 'COMPLETED') { + const resultFetch = await fetch(statusData?.response_url, { + method: 'GET', + headers: { + 'Authorization': `Key ${key}`, + }, + }); + const resultData = await resultFetch.json(); + const imageFetch = await fetch(resultData?.images[0].url, { + headers: { + 'Authorization': `Key ${key}`, + }, + }); + + const fetchData = await imageFetch.arrayBuffer(); + const image = Buffer.from(fetchData).toString('base64'); + return response.send({ image: image }); + } + + throw new Error('FAL.AI failed to generate image.', { cause: statusData }); + } + } catch (error) { + console.error(error); + return response.sendStatus(500); + } +}); + router.use('/comfy', comfy); router.use('/together', together); router.use('/drawthings', drawthings); @@ -1237,3 +1356,4 @@ router.use('/blockentropy', blockentropy); router.use('/huggingface', huggingface); router.use('/nanogpt', nanogpt); router.use('/bfl', bfl); +router.use('/falai', falai); From b033b98532b690ac7cc96844fc3ec229be3dde76 Mon Sep 17 00:00:00 2001 From: Kristan Schlikow Date: Thu, 13 Feb 2025 21:09:13 +0100 Subject: [PATCH 2/3] Address issues raised in PR --- public/scripts/extensions/stable-diffusion/index.js | 1 + public/scripts/extensions/stable-diffusion/settings.html | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index 6d410a309..74378396e 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -1304,6 +1304,7 @@ async function onModelChange() { sources.huggingface, sources.nanogpt, sources.bfl, + sources.falai, ]; if (cloudSources.includes(extension_settings.sd.source)) { diff --git a/public/scripts/extensions/stable-diffusion/settings.html b/public/scripts/extensions/stable-diffusion/settings.html index 7f766d8cd..32ecfe28f 100644 --- a/public/scripts/extensions/stable-diffusion/settings.html +++ b/public/scripts/extensions/stable-diffusion/settings.html @@ -42,6 +42,7 @@ + @@ -52,7 +53,6 @@ -
From 76becb43ae0fbbf4842fbc1ea5104ac851de5d7c Mon Sep 17 00:00:00 2001 From: Kristan Schlikow Date: Thu, 13 Feb 2025 21:43:08 +0100 Subject: [PATCH 3/3] Pass through errors coming from FAL to the user --- public/scripts/extensions/stable-diffusion/index.js | 1 + src/endpoints/stable-diffusion.js | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index 74378396e..58429d4fd 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -3555,6 +3555,7 @@ async function generateFalaiImage(prompt, negativePrompt, signal) { return { format: 'jpg', data: data.image }; } else { const text = await result.text(); + console.log(text); throw new Error(text); } } diff --git a/src/endpoints/stable-diffusion.js b/src/endpoints/stable-diffusion.js index 2c23bab98..dd6dc9adf 100644 --- a/src/endpoints/stable-diffusion.js +++ b/src/endpoints/stable-diffusion.js @@ -1251,6 +1251,7 @@ falai.post('/models', jsonParser, async (_request, response) => { .filter(x => !x.title.toLowerCase().includes('inpainting') && !x.title.toLowerCase().includes('control') && !x.title.toLowerCase().includes('upscale')) + .sort((a, b) => a.title.localeCompare(b.title)) .map(x => ({ value: x.modelUrl.split('fal-ai/')[1], text: x.title })); return response.send(models); } catch (error) { @@ -1328,6 +1329,11 @@ falai.post('/generate', jsonParser, async (request, response) => { }, }); const resultData = await resultFetch.json(); + + if (resultData.detail !== null && resultData.detail !== undefined) { + throw new Error('FAL.AI failed to generate image.', { cause: `${resultData.detail[0].loc[1]}: ${resultData.detail[0].msg}` }); + } + const imageFetch = await fetch(resultData?.images[0].url, { headers: { 'Authorization': `Key ${key}`, @@ -1343,7 +1349,7 @@ falai.post('/generate', jsonParser, async (request, response) => { } } catch (error) { console.error(error); - return response.sendStatus(500); + return response.status(500).send(error.cause || error.message); } });