Add cancelling of SD gens

This commit is contained in:
Cohee
2024-07-18 23:23:33 +03:00
parent 88acb568ad
commit 1effb66fd6
2 changed files with 74 additions and 25 deletions

View File

@ -2289,24 +2289,34 @@ async function generatePicture(initiator, args, trigger, message, callback) {
} }
const dimensions = setTypeSpecificDimensions(generationType); const dimensions = setTypeSpecificDimensions(generationType);
const abortController = new AbortController();
let negativePromptPrefix = args?.negative || ''; let negativePromptPrefix = args?.negative || '';
let imagePath = ''; let imagePath = '';
const stopListener = () => abortController.abort('Aborted by user');
const mesStop = document.getElementById('mes_stop');
try { try {
const combineNegatives = (prefix) => { negativePromptPrefix = combinePrefixes(negativePromptPrefix, prefix); }; const combineNegatives = (prefix) => { negativePromptPrefix = combinePrefixes(negativePromptPrefix, prefix); };
const prompt = await getPrompt(generationType, message, trigger, quietPrompt, combineNegatives); const prompt = await getPrompt(generationType, message, trigger, quietPrompt, combineNegatives);
console.log('Processed image prompt:', prompt); console.log('Processed image prompt:', prompt);
mesStop?.addEventListener('click', stopListener);
context.deactivateSendButtons(); context.deactivateSendButtons();
hideSwipeButtons(); hideSwipeButtons();
imagePath = await sendGenerationRequest(generationType, prompt, negativePromptPrefix, characterName, callback, initiator); if (typeof args?._abortController?.addEventListener === 'function') {
args._abortController.addEventListener('abort', stopListener);
}
imagePath = await sendGenerationRequest(generationType, prompt, negativePromptPrefix, characterName, callback, initiator, abortController.signal);
} catch (err) { } catch (err) {
console.trace(err); console.trace(err);
throw new Error('SD prompt text generation failed.'); throw new Error('SD prompt text generation failed.');
} }
finally { finally {
restoreOriginalDimensions(dimensions); restoreOriginalDimensions(dimensions);
mesStop?.removeEventListener('click', stopListener);
context.activateSendButtons(); context.activateSendButtons();
showSwipeButtons(); showSwipeButtons();
} }
@ -2521,9 +2531,10 @@ async function generatePrompt(quietPrompt) {
* @param {string} characterName Name of the character * @param {string} characterName Name of the character
* @param {function} callback Callback function to be called after image generation * @param {function} callback Callback function to be called after image generation
* @param {string} initiator The initiator of the image generation * @param {string} initiator The initiator of the image generation
* @param {AbortSignal} signal Abort signal to cancel the request
* @returns * @returns
*/ */
async function sendGenerationRequest(generationType, prompt, additionalNegativePrefix, characterName, callback, initiator) { async function sendGenerationRequest(generationType, prompt, additionalNegativePrefix, characterName, callback, initiator, signal) {
const noCharPrefix = [generationMode.FREE, generationMode.BACKGROUND, generationMode.USER, generationMode.USER_MULTIMODAL, generationMode.FREE_EXTENDED]; const noCharPrefix = [generationMode.FREE, generationMode.BACKGROUND, generationMode.USER, generationMode.USER_MULTIMODAL, generationMode.FREE_EXTENDED];
const prefix = noCharPrefix.includes(generationType) const prefix = noCharPrefix.includes(generationType)
? extension_settings.sd.prompt_prefix ? extension_settings.sd.prompt_prefix
@ -2541,37 +2552,37 @@ async function sendGenerationRequest(generationType, prompt, additionalNegativeP
try { try {
switch (extension_settings.sd.source) { switch (extension_settings.sd.source) {
case sources.extras: case sources.extras:
result = await generateExtrasImage(prefixedPrompt, negativePrompt); result = await generateExtrasImage(prefixedPrompt, negativePrompt, signal);
break; break;
case sources.horde: case sources.horde:
result = await generateHordeImage(prefixedPrompt, negativePrompt); result = await generateHordeImage(prefixedPrompt, negativePrompt, signal);
break; break;
case sources.vlad: case sources.vlad:
result = await generateAutoImage(prefixedPrompt, negativePrompt); result = await generateAutoImage(prefixedPrompt, negativePrompt, signal);
break; break;
case sources.drawthings: case sources.drawthings:
result = await generateDrawthingsImage(prefixedPrompt, negativePrompt); result = await generateDrawthingsImage(prefixedPrompt, negativePrompt, signal);
break; break;
case sources.auto: case sources.auto:
result = await generateAutoImage(prefixedPrompt, negativePrompt); result = await generateAutoImage(prefixedPrompt, negativePrompt, signal);
break; break;
case sources.novel: case sources.novel:
result = await generateNovelImage(prefixedPrompt, negativePrompt); result = await generateNovelImage(prefixedPrompt, negativePrompt, signal);
break; break;
case sources.openai: case sources.openai:
result = await generateOpenAiImage(prefixedPrompt); result = await generateOpenAiImage(prefixedPrompt, signal);
break; break;
case sources.comfy: case sources.comfy:
result = await generateComfyImage(prefixedPrompt, negativePrompt); result = await generateComfyImage(prefixedPrompt, negativePrompt, signal);
break; break;
case sources.togetherai: case sources.togetherai:
result = await generateTogetherAIImage(prefixedPrompt, negativePrompt); result = await generateTogetherAIImage(prefixedPrompt, negativePrompt, signal);
break; break;
case sources.pollinations: case sources.pollinations:
result = await generatePollinationsImage(prefixedPrompt, negativePrompt); result = await generatePollinationsImage(prefixedPrompt, negativePrompt, signal);
break; break;
case sources.stability: case sources.stability:
result = await generateStabilityImage(prefixedPrompt, negativePrompt); result = await generateStabilityImage(prefixedPrompt, negativePrompt, signal);
break; break;
} }
@ -2600,12 +2611,14 @@ async function sendGenerationRequest(generationType, prompt, additionalNegativeP
* Generates an image using the TogetherAI API. * Generates an image using the TogetherAI API.
* @param {string} prompt - The main instruction used to guide the image generation. * @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation. * @param {string} negativePrompt - The instruction used to restrict the image generation.
* @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/ */
async function generateTogetherAIImage(prompt, negativePrompt) { async function generateTogetherAIImage(prompt, negativePrompt, signal) {
const result = await fetch('/api/sd/together/generate', { const result = await fetch('/api/sd/together/generate', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
signal: signal,
body: JSON.stringify({ body: JSON.stringify({
prompt: prompt, prompt: prompt,
negative_prompt: negativePrompt, negative_prompt: negativePrompt,
@ -2630,12 +2643,14 @@ async function generateTogetherAIImage(prompt, negativePrompt) {
* Generates an image using the Pollinations API. * Generates an image using the Pollinations API.
* @param {string} prompt - The main instruction used to guide the image generation. * @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation. * @param {string} negativePrompt - The instruction used to restrict the image generation.
* @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/ */
async function generatePollinationsImage(prompt, negativePrompt) { async function generatePollinationsImage(prompt, negativePrompt, signal) {
const result = await fetch('/api/sd/pollinations/generate', { const result = await fetch('/api/sd/pollinations/generate', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
signal: signal,
body: JSON.stringify({ body: JSON.stringify({
prompt: prompt, prompt: prompt,
negative_prompt: negativePrompt, negative_prompt: negativePrompt,
@ -2662,9 +2677,10 @@ async function generatePollinationsImage(prompt, negativePrompt) {
* *
* @param {string} prompt - The main instruction used to guide the image generation. * @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation. * @param {string} negativePrompt - The instruction used to restrict the image generation.
* @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/ */
async function generateExtrasImage(prompt, negativePrompt) { async function generateExtrasImage(prompt, negativePrompt, signal) {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/image'; url.pathname = '/api/image';
const result = await doExtrasFetch(url, { const result = await doExtrasFetch(url, {
@ -2672,6 +2688,7 @@ async function generateExtrasImage(prompt, negativePrompt) {
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
}, },
signal: signal,
body: JSON.stringify({ body: JSON.stringify({
prompt: prompt, prompt: prompt,
sampler: extension_settings.sd.sampler, sampler: extension_settings.sd.sampler,
@ -2739,9 +2756,10 @@ function getClosestAspectRatio(width, height) {
* Generates an image using Stability AI. * Generates an image using Stability AI.
* @param {string} prompt - The main instruction used to guide the image generation. * @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation. * @param {string} negativePrompt - The instruction used to restrict the image generation.
* @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/ */
async function generateStabilityImage(prompt, negativePrompt) { async function generateStabilityImage(prompt, negativePrompt, signal) {
const IMAGE_FORMAT = 'png'; const IMAGE_FORMAT = 'png';
const PROMPT_LIMIT = 10000; const PROMPT_LIMIT = 10000;
@ -2749,6 +2767,7 @@ async function generateStabilityImage(prompt, negativePrompt) {
const response = await fetch('/api/sd/stability/generate', { const response = await fetch('/api/sd/stability/generate', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
signal: signal,
body: JSON.stringify({ body: JSON.stringify({
model: extension_settings.sd.model, model: extension_settings.sd.model,
payload: { payload: {
@ -2783,12 +2802,14 @@ async function generateStabilityImage(prompt, negativePrompt) {
* *
* @param {string} prompt - The main instruction used to guide the image generation. * @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation. * @param {string} negativePrompt - The instruction used to restrict the image generation.
* @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/ */
async function generateHordeImage(prompt, negativePrompt) { async function generateHordeImage(prompt, negativePrompt, signal) {
const result = await fetch('/api/horde/generate-image', { const result = await fetch('/api/horde/generate-image', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
signal: signal,
body: JSON.stringify({ body: JSON.stringify({
prompt: prompt, prompt: prompt,
sampler: extension_settings.sd.sampler, sampler: extension_settings.sd.sampler,
@ -2821,13 +2842,15 @@ async function generateHordeImage(prompt, negativePrompt) {
* *
* @param {string} prompt - The main instruction used to guide the image generation. * @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation. * @param {string} negativePrompt - The instruction used to restrict the image generation.
* @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/ */
async function generateAutoImage(prompt, negativePrompt) { async function generateAutoImage(prompt, negativePrompt, signal) {
const isValidVae = extension_settings.sd.vae && !['N/A', placeholderVae].includes(extension_settings.sd.vae); const isValidVae = extension_settings.sd.vae && !['N/A', placeholderVae].includes(extension_settings.sd.vae);
const result = await fetch('/api/sd/generate', { const result = await fetch('/api/sd/generate', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
signal: signal,
body: JSON.stringify({ body: JSON.stringify({
...getSdRequestBody(), ...getSdRequestBody(),
prompt: prompt, prompt: prompt,
@ -2875,12 +2898,14 @@ async function generateAutoImage(prompt, negativePrompt) {
* *
* @param {string} prompt - The main instruction used to guide the image generation. * @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation. * @param {string} negativePrompt - The instruction used to restrict the image generation.
* @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/ */
async function generateDrawthingsImage(prompt, negativePrompt) { async function generateDrawthingsImage(prompt, negativePrompt, signal) {
const result = await fetch('/api/sd/drawthings/generate', { const result = await fetch('/api/sd/drawthings/generate', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
signal: signal,
body: JSON.stringify({ body: JSON.stringify({
...getSdRequestBody(), ...getSdRequestBody(),
prompt: prompt, prompt: prompt,
@ -2914,14 +2939,16 @@ async function generateDrawthingsImage(prompt, negativePrompt) {
* *
* @param {string} prompt - The main instruction used to guide the image generation. * @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation. * @param {string} negativePrompt - The instruction used to restrict the image generation.
* @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/ */
async function generateNovelImage(prompt, negativePrompt) { async function generateNovelImage(prompt, negativePrompt, signal) {
const { steps, width, height, sm, sm_dyn } = getNovelParams(); const { steps, width, height, sm, sm_dyn } = getNovelParams();
const result = await fetch('/api/novelai/generate-image', { const result = await fetch('/api/novelai/generate-image', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
signal: signal,
body: JSON.stringify({ body: JSON.stringify({
prompt: prompt, prompt: prompt,
model: extension_settings.sd.model, model: extension_settings.sd.model,
@ -3010,7 +3037,13 @@ function getNovelParams() {
return { steps, width, height, sm, sm_dyn }; return { steps, width, height, sm, sm_dyn };
} }
async function generateOpenAiImage(prompt) { /**
* Generates an image in OpenAI API using the provided prompt and configuration settings.
* @param {string} prompt - The main instruction used to guide the image generation.
* @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/
async function generateOpenAiImage(prompt, signal) {
const dalle2PromptLimit = 1000; const dalle2PromptLimit = 1000;
const dalle3PromptLimit = 4000; const dalle3PromptLimit = 4000;
@ -3045,6 +3078,7 @@ async function generateOpenAiImage(prompt) {
const result = await fetch('/api/openai/generate-image', { const result = await fetch('/api/openai/generate-image', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
signal: signal,
body: JSON.stringify({ body: JSON.stringify({
prompt: prompt, prompt: prompt,
model: extension_settings.sd.model, model: extension_settings.sd.model,
@ -3070,9 +3104,10 @@ async function generateOpenAiImage(prompt) {
* *
* @param {string} prompt - The main instruction used to guide the image generation. * @param {string} prompt - The main instruction used to guide the image generation.
* @param {string} negativePrompt - The instruction used to restrict the image generation. * @param {string} negativePrompt - The instruction used to restrict the image generation.
* @param {AbortSignal} signal - An AbortSignal object that can be used to cancel the request.
* @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete. * @returns {Promise<{format: string, data: string}>} - A promise that resolves when the image generation and processing are complete.
*/ */
async function generateComfyImage(prompt, negativePrompt) { async function generateComfyImage(prompt, negativePrompt, signal) {
const placeholders = [ const placeholders = [
'model', 'model',
'vae', 'vae',
@ -3133,6 +3168,7 @@ async function generateComfyImage(prompt, negativePrompt) {
const promptResult = await fetch('/api/sd/comfy/generate', { const promptResult = await fetch('/api/sd/comfy/generate', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
signal: signal,
body: JSON.stringify({ body: JSON.stringify({
url: extension_settings.sd.comfy_url, url: extension_settings.sd.comfy_url,
prompt: `{ prompt: `{
@ -3245,7 +3281,7 @@ async function onComfyNewWorkflowClick() {
if (!name) { if (!name) {
return; return;
} }
if (!name.toLowerCase().endsWith('.json')) { if (!String(name).toLowerCase().endsWith('.json')) {
name += '.json'; name += '.json';
} }
extension_settings.sd.comfy_workflow = name; extension_settings.sd.comfy_workflow = name;
@ -3448,8 +3484,10 @@ async function sdMessageButton(e) {
const messageText = message?.mes; const messageText = message?.mes;
const hasSavedImage = message?.extra?.image && message?.extra?.title; const hasSavedImage = message?.extra?.image && message?.extra?.title;
const hasSavedNegative = message?.extra?.negative; const hasSavedNegative = message?.extra?.negative;
const abortController = new AbortController();
if ($icon.hasClass(busyClass)) { if ($icon.hasClass(busyClass)) {
abortController.abort();
console.log('Previous image is still being generated...'); console.log('Previous image is still being generated...');
return; return;
} }
@ -3466,7 +3504,7 @@ async function sdMessageButton(e) {
const generationType = message?.extra?.generationType ?? generationMode.FREE; const generationType = message?.extra?.generationType ?? generationMode.FREE;
console.log('Regenerating an image, using existing prompt:', prompt); console.log('Regenerating an image, using existing prompt:', prompt);
dimensions = setTypeSpecificDimensions(generationType); dimensions = setTypeSpecificDimensions(generationType);
await sendGenerationRequest(generationType, prompt, negative, characterFileName, saveGeneratedImage, initiators.action); await sendGenerationRequest(generationType, prompt, negative, characterFileName, saveGeneratedImage, initiators.action, abortController.signal);
} }
else { else {
console.log('doing /sd raw last'); console.log('doing /sd raw last');

View File

@ -339,7 +339,18 @@ router.post('/generate-image', jsonParser, async (request, response) => {
return response.sendStatus(400); return response.sendStatus(400);
} }
console.log('Horde image generation request:', generation);
const controller = new AbortController();
request.socket.removeAllListeners('close');
request.socket.on('close', function () {
console.log('Horde image generation request aborted.');
controller.abort();
if (generation.id) ai_horde.deleteImageGenerationRequest(generation.id);
});
for (let attempt = 0; attempt < MAX_ATTEMPTS; attempt++) { for (let attempt = 0; attempt < MAX_ATTEMPTS; attempt++) {
controller.signal.throwIfAborted();
await delay(CHECK_INTERVAL); await delay(CHECK_INTERVAL);
const check = await ai_horde.getImageGenerationCheck(generation.id); const check = await ai_horde.getImageGenerationCheck(generation.id);
console.log(check); console.log(check);