Add xAI for image generation extension

Cohee committed 2025-04-11 20:32:06 +03:00
parent c1544fb60c · commit 0c4c86ef06
3 changed files with 108 additions and 0 deletions

View File

@@ -81,6 +81,7 @@ const sources = {
    nanogpt: 'nanogpt',
    bfl: 'bfl',
    falai: 'falai',
    xai: 'xai',
};

const initiators = {
@@ -1303,6 +1304,7 @@ async function onModelChange() {
        sources.nanogpt,
        sources.bfl,
        sources.falai,
        sources.xai,
    ];

    if (cloudSources.includes(extension_settings.sd.source)) {
@@ -1518,6 +1520,9 @@ async function loadSamplers() {
        case sources.bfl:
            samplers = ['N/A'];
            break;
        case sources.xai:
            samplers = ['N/A'];
            break;
    }

    for (const sampler of samplers) {
@@ -1708,6 +1713,9 @@ async function loadModels() {
        case sources.falai:
            models = await loadFalaiModels();
            break;
        case sources.xai:
            models = await loadXAIModels();
            break;
    }

    for (const model of models) {
@@ -1760,6 +1768,12 @@ async function loadFalaiModels() {
    return [];
}

async function loadXAIModels() {
    return [
        { value: 'grok-2-image-1212', text: 'grok-2-image-1212' },
    ];
}

async function loadPollinationsModels() {
    const result = await fetch('/api/sd/pollinations/models', {
        method: 'POST',
@@ -2081,6 +2095,9 @@ async function loadSchedulers() {
        case sources.falai:
            schedulers = ['N/A'];
            break;
        case sources.xai:
            schedulers = ['N/A'];
            break;
    }

    for (const scheduler of schedulers) {
@@ -2166,6 +2183,12 @@ async function loadVaes() {
        case sources.bfl:
            vaes = ['N/A'];
            break;
        case sources.falai:
            vaes = ['N/A'];
            break;
        case sources.xai:
            vaes = ['N/A'];
            break;
    }

    for (const vae of vaes) {
@@ -2735,6 +2758,9 @@ async function sendGenerationRequest(generationType, prompt, additionalNegativeP
        case sources.falai:
            result = await generateFalaiImage(prefixedPrompt, negativePrompt, signal);
            break;
        case sources.xai:
            result = await generateXAIImage(prefixedPrompt, negativePrompt, signal);
            break;
    }

    if (!result.data) {
@@ -3463,6 +3489,33 @@ async function generateBflImage(prompt, signal) {
    }
}

/**
 * Generates an image using the xAI API.
 * @param {string} prompt The main instruction used to guide the image generation.
 * @param {string} _negativePrompt Negative prompt is not used in this API
 * @param {AbortSignal} signal An AbortSignal object that can be used to cancel the request.
 * @returns {Promise<{format: string, data: string}>} A promise that resolves when the image generation and processing are complete.
 */
async function generateXAIImage(prompt, _negativePrompt, signal) {
    const result = await fetch('/api/sd/xai/generate', {
        method: 'POST',
        headers: getRequestHeaders(),
        signal: signal,
        body: JSON.stringify({
            prompt: prompt,
            model: extension_settings.sd.model,
        }),
    });

    if (result.ok) {
        const data = await result.json();
        return { format: 'jpg', data: data.image };
    } else {
        const text = await result.text();
        throw new Error(text);
    }
}

/**
 * Generates an image using the FAL.AI API.
 * @param {string} prompt - The main instruction used to guide the image generation.
@@ -3782,6 +3835,8 @@ function isValidState() {
            return secret_state[SECRET_KEYS.BFL];
        case sources.falai:
            return secret_state[SECRET_KEYS.FALAI];
        case sources.xai:
            return secret_state[SECRET_KEYS.XAI];
    }
}
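
Taken together, the client-side changes treat xAI like the other hosted backends: a single hard-coded model, 'N/A' samplers, schedulers, and VAEs, and a generate function that resolves to the { format, data } object sendGenerationRequest expects. A minimal usage sketch of that contract (illustrative only; the prompt text and the AbortController are placeholders, not part of the commit):

const controller = new AbortController();
// Resolves to something like { format: 'jpg', data: '<base64-encoded image>' },
// or throws with the server's error text if the request fails.
const result = await generateXAIImage('a foggy mountain pass at dawn', '', controller.signal);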

View File

@@ -52,6 +52,7 @@
        <option value="auto">Stable Diffusion Web UI (AUTOMATIC1111)</option>
        <option value="horde">Stable Horde</option>
        <option value="togetherai">TogetherAI</option>
        <option value="xai">xAI (Grok)</option>
    </select>
    <div data-sd-source="auto">
        <label for="sd_auto_url">SD Web UI URL</label>

View File

@@ -1245,6 +1245,7 @@ falai.post('/generate', async (request, response) => {
                'Authorization': `Key ${key}`,
            },
        });
        /** @type {any} */
        const resultData = await resultFetch.json();
        if (resultData.detail !== null && resultData.detail !== undefined) {
@@ -1270,6 +1271,56 @@ falai.post('/generate', async (request, response) => {
    }
});

const xai = express.Router();

xai.post('/generate', async (request, response) => {
    try {
        const key = readSecret(request.user.directories, SECRET_KEYS.XAI);

        if (!key) {
            console.warn('xAI key not found.');
            return response.sendStatus(400);
        }

        const requestBody = {
            prompt: request.body.prompt,
            model: request.body.model,
            response_format: 'b64_json',
        };

        console.debug('xAI request:', requestBody);

        const result = await fetch('https://api.x.ai/v1/images/generations', {
            method: 'POST',
            body: JSON.stringify(requestBody),
            headers: {
                'Content-Type': 'application/json',
                'Authorization': `Bearer ${key}`,
            },
        });

        if (!result.ok) {
            const text = await result.text();
            console.warn('xAI returned an error.', text);
            return response.sendStatus(500);
        }

        /** @type {any} */
        const data = await result.json();
        const image = data?.data?.[0]?.b64_json;

        if (!image) {
            console.warn('xAI returned invalid data.');
            return response.sendStatus(500);
        }

        return response.send({ image });
    } catch (error) {
        console.error('Error communicating with xAI', error);
        return response.sendStatus(500);
    }
});

router.use('/comfy', comfy);
router.use('/together', together);
router.use('/drawthings', drawthings);
@@ -1279,3 +1330,4 @@ router.use('/huggingface', huggingface);
router.use('/nanogpt', nanogpt);
router.use('/bfl', bfl);
router.use('/falai', falai);
router.use('/xai', xai);
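
For context, the upstream exchange that the new /api/sd/xai/generate route proxies looks roughly like the sketch below. The endpoint URL, Bearer authorization, request body fields, and the data[0].b64_json response field all come from the handler above; the standalone scaffolding (Node 18+ global fetch, an XAI_API_KEY environment variable, the sample prompt) is illustrative and not part of the commit.

// Minimal sketch of the call the /generate handler makes to xAI.
// Assumes Node 18+ fetch and an XAI_API_KEY env var; prompt and model are placeholders.
const response = await fetch('https://api.x.ai/v1/images/generations', {
    method: 'POST',
    headers: {
        'Content-Type': 'application/json',
        'Authorization': `Bearer ${process.env.XAI_API_KEY}`,
    },
    body: JSON.stringify({
        prompt: 'a watercolor lighthouse at dusk',
        model: 'grok-2-image-1212',
        response_format: 'b64_json',
    }),
});

const payload = await response.json();
const base64Image = payload?.data?.[0]?.b64_json; // what the route returns to the client as { image }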