diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index 9f9960c7d..58429d4fd 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -81,6 +81,7 @@ const sources = { huggingface: 'huggingface', nanogpt: 'nanogpt', bfl: 'bfl', + falai: 'falai', }; const initiators = { @@ -1169,6 +1170,10 @@ async function onBflKeyClick() { return onApiKeyClick('BFL API Key:', SECRET_KEYS.BFL); } +async function onFalaiKeyClick() { + return onApiKeyClick('FALAI API Key:', SECRET_KEYS.FALAI); +} + function onBflUpsamplingInput() { extension_settings.sd.bfl_upsampling = !!$('#sd_bfl_upsampling').prop('checked'); saveSettingsDebounced(); @@ -1299,6 +1304,7 @@ async function onModelChange() { sources.huggingface, sources.nanogpt, sources.bfl, + sources.falai, ]; if (cloudSources.includes(extension_settings.sd.source)) { @@ -1707,6 +1713,9 @@ async function loadModels() { case sources.bfl: models = await loadBflModels(); break; + case sources.falai: + models = await loadFalaiModels(); + break; } for (const model of models) { @@ -1744,6 +1753,21 @@ async function loadBflModels() { ]; } +async function loadFalaiModels() { + $('#sd_falai_key').toggleClass('success', !!secret_state[SECRET_KEYS.FALAI]); + + const result = await fetch('/api/sd/falai/models', { + method: 'POST', + headers: getRequestHeaders(), + }); + + if (result.ok) { + return await result.json(); + } + + return []; +} + async function loadPollinationsModels() { const result = await fetch('/api/sd/pollinations/models', { method: 'POST', @@ -2081,6 +2105,9 @@ async function loadSchedulers() { case sources.bfl: schedulers = ['N/A']; break; + case sources.falai: + schedulers = ['N/A']; + break; } for (const scheduler of schedulers) { @@ -2735,6 +2762,9 @@ async function sendGenerationRequest(generationType, prompt, additionalNegativeP case sources.bfl: result = await generateBflImage(prefixedPrompt, signal); break; + case sources.falai: + result = await generateFalaiImage(prefixedPrompt, negativePrompt, signal); + break; } if (!result.data) { @@ -3496,6 +3526,40 @@ async function generateBflImage(prompt, signal) { } } +/** + * Generates an image using the FAL.AI API. + * @param {string} prompt - The main instruction used to guide the image generation. + * @param {string} negativePrompt - The negative prompt used to guide the image generation. + * @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request. + * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. + */ +async function generateFalaiImage(prompt, negativePrompt, signal) { + const result = await fetch('/api/sd/falai/generate', { + method: 'POST', + headers: getRequestHeaders(), + signal: signal, + body: JSON.stringify({ + prompt: prompt, + negative_prompt: negativePrompt, + model: extension_settings.sd.model, + steps: clamp(extension_settings.sd.steps, 1, 50), + guidance: clamp(extension_settings.sd.scale, 1.5, 5), + width: clamp(extension_settings.sd.width, 256, 1440), + height: clamp(extension_settings.sd.height, 256, 1440), + seed: extension_settings.sd.seed >= 0 ? extension_settings.sd.seed : undefined, + }), + }); + + if (result.ok) { + const data = await result.json(); + return { format: 'jpg', data: data.image }; + } else { + const text = await result.text(); + console.log(text); + throw new Error(text); + } +} + async function onComfyOpenWorkflowEditorClick() { let workflow = await (await fetch('/api/sd/comfy/workflow', { method: 'POST', @@ -3782,6 +3846,8 @@ function isValidState() { return secret_state[SECRET_KEYS.NANOGPT]; case sources.bfl: return secret_state[SECRET_KEYS.BFL]; + case sources.falai: + return secret_state[SECRET_KEYS.FALAI]; } } @@ -4443,6 +4509,7 @@ jQuery(async () => { $('#sd_function_tool').on('input', onFunctionToolInput); $('#sd_bfl_key').on('click', onBflKeyClick); $('#sd_bfl_upsampling').on('input', onBflUpsamplingInput); + $('#sd_falai_key').on('click', onFalaiKeyClick); if (!CSS.supports('field-sizing', 'content')) { $('.sd_settings .inline-drawer-toggle').on('click', function () { diff --git a/public/scripts/extensions/stable-diffusion/settings.html b/public/scripts/extensions/stable-diffusion/settings.html index 7969ef7e2..32ecfe28f 100644 --- a/public/scripts/extensions/stable-diffusion/settings.html +++ b/public/scripts/extensions/stable-diffusion/settings.html @@ -42,6 +42,7 @@ + @@ -256,6 +257,20 @@ +
+
+ + API Key + + + + +
+
+
diff --git a/public/scripts/secrets.js b/public/scripts/secrets.js index f2e27f3e2..261c31480 100644 --- a/public/scripts/secrets.js +++ b/public/scripts/secrets.js @@ -41,6 +41,7 @@ export const SECRET_KEYS = { GENERIC: 'api_key_generic', DEEPSEEK: 'api_key_deepseek', SERPER: 'api_key_serper', + FALAI: 'api_key_falai', }; const INPUT_MAP = { diff --git a/src/endpoints/secrets.js b/src/endpoints/secrets.js index 5683da4d8..a7f0094ba 100644 --- a/src/endpoints/secrets.js +++ b/src/endpoints/secrets.js @@ -50,6 +50,7 @@ export const SECRET_KEYS = { TAVILY: 'api_key_tavily', NANOGPT: 'api_key_nanogpt', BFL: 'api_key_bfl', + FALAI: 'api_key_falai', GENERIC: 'api_key_generic', DEEPSEEK: 'api_key_deepseek', SERPER: 'api_key_serper', diff --git a/src/endpoints/stable-diffusion.js b/src/endpoints/stable-diffusion.js index 2d611dc6b..dd6dc9adf 100644 --- a/src/endpoints/stable-diffusion.js +++ b/src/endpoints/stable-diffusion.js @@ -1228,6 +1228,131 @@ bfl.post('/generate', jsonParser, async (request, response) => { } }); +const falai = express.Router(); + +falai.post('/models', jsonParser, async (_request, response) => { + try { + const modelsUrl = new URL('https://fal.ai/api/models?categories=text-to-image'); + const result = await fetch(modelsUrl); + + if (!result.ok) { + console.warn('FAL.AI returned an error.', result.status, result.statusText); + throw new Error('FAL.AI request failed.'); + } + + const data = await result.json(); + + if (!Array.isArray(data)) { + console.warn('FAL.AI returned invalid data.'); + throw new Error('FAL.AI request failed.'); + } + + const models = data + .filter(x => !x.title.toLowerCase().includes('inpainting') && + !x.title.toLowerCase().includes('control') && + !x.title.toLowerCase().includes('upscale')) + .sort((a, b) => a.title.localeCompare(b.title)) + .map(x => ({ value: x.modelUrl.split('fal-ai/')[1], text: x.title })); + return response.send(models); + } catch (error) { + console.error(error); + return response.sendStatus(500); + } +}); + +falai.post('/generate', jsonParser, async (request, response) => { + try { + const key = readSecret(request.user.directories, SECRET_KEYS.FALAI); + + if (!key) { + console.warn('FAL.AI key not found.'); + return response.sendStatus(400); + } + + const requestBody = { + prompt: request.body.prompt, + image_size: { 'width': request.body.width, 'height': request.body.height }, + num_inference_steps: request.body.steps, + seed: request.body.seed ?? null, + guidance_scale: request.body.guidance, + enable_safety_checker: false, + }; + + console.debug('FAL.AI request:', requestBody); + + const result = await fetch(`https://queue.fal.run/fal-ai/${request.body.model}`, { + method: 'POST', + body: JSON.stringify(requestBody), + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Key ${key}`, + }, + }); + + if (!result.ok) { + console.warn('FAL.AI returned an error.'); + return response.sendStatus(500); + } + + /** @type {any} */ + const taskData = await result.json(); + const { status_url } = taskData; + + const MAX_ATTEMPTS = 100; + for (let i = 0; i < MAX_ATTEMPTS; i++) { + await delay(2500); + + const statusResult = await fetch(status_url, { + headers: { + 'Authorization': `Key ${key}`, + }, + }); + + if (!statusResult.ok) { + const text = await statusResult.text(); + console.warn('FAL.AI returned an error.', text); + return response.sendStatus(500); + } + + /** @type {any} */ + const statusData = await statusResult.json(); + + if (statusData?.status === 'IN_QUEUE' || statusData?.status === 'IN_PROGRESS') { + continue; + } + + if (statusData?.status === 'COMPLETED') { + const resultFetch = await fetch(statusData?.response_url, { + method: 'GET', + headers: { + 'Authorization': `Key ${key}`, + }, + }); + const resultData = await resultFetch.json(); + + if (resultData.detail !== null && resultData.detail !== undefined) { + throw new Error('FAL.AI failed to generate image.', { cause: `${resultData.detail[0].loc[1]}: ${resultData.detail[0].msg}` }); + } + + const imageFetch = await fetch(resultData?.images[0].url, { + headers: { + 'Authorization': `Key ${key}`, + }, + }); + + const fetchData = await imageFetch.arrayBuffer(); + const image = Buffer.from(fetchData).toString('base64'); + return response.send({ image: image }); + } + + throw new Error('FAL.AI failed to generate image.', { cause: statusData }); + } + } catch (error) { + console.error(error); + return response.status(500).send(error.cause || error.message); + } +}); + router.use('/comfy', comfy); router.use('/together', together); router.use('/drawthings', drawthings); @@ -1237,3 +1362,4 @@ router.use('/blockentropy', blockentropy); router.use('/huggingface', huggingface); router.use('/nanogpt', nanogpt); router.use('/bfl', bfl); +router.use('/falai', falai);