Add SD prompt expansion

This commit is contained in:
Cohee 2023-10-20 15:03:26 +03:00
parent 5c6343e85e
commit c4e6b565a5
5 changed files with 129 additions and 1 deletions

View File

@ -26,6 +26,8 @@ const extras = {
captioningModel: 'Xenova/vit-gpt2-image-captioning',
// Feature extraction model. HuggingFace ID of a model in ONNX format.
embeddingModel: 'Xenova/all-mpnet-base-v2',
// GPT-2 text generation model. HuggingFace ID of a model in ONNX format.
promptExpansionModel: 'Cohee/fooocus_expansion-onnx',
};
// Request overrides for additional headers

View File

@ -160,6 +160,7 @@ const defaultSettings = {
// Refine mode
refine_mode: false,
expand: false,
prompts: promptTemplates,
@ -257,6 +258,7 @@ async function loadSettings() {
$('#sd_restore_faces').prop('checked', extension_settings.sd.restore_faces);
$('#sd_enable_hr').prop('checked', extension_settings.sd.enable_hr);
$('#sd_refine_mode').prop('checked', extension_settings.sd.refine_mode);
$('#sd_expand').prop('checked', extension_settings.sd.expand);
$('#sd_auto_url').val(extension_settings.sd.auto_url);
$('#sd_auto_auth').val(extension_settings.sd.auto_auth);
$('#sd_vlad_url').val(extension_settings.sd.vlad_url);
@ -300,7 +302,30 @@ function addPromptTemplates() {
}
}
async function expandPrompt(prompt) {
try {
const response = await fetch('/api/sd/expand', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({ prompt: prompt }),
});
if (!response.ok) {
throw new Error('API returned an error.');
}
const data = await response.json();
return data.prompt;
} catch {
return prompt;
}
}
async function refinePrompt(prompt) {
if (extension_settings.sd.expand) {
prompt = await expandPrompt(prompt);
}
if (extension_settings.sd.refine_mode) {
const refinedPrompt = await callPopup('<h3>Review and edit the prompt:</h3>Press "Cancel" to abort the image generation.', 'input', prompt.trim(), { rows: 5, okButton: 'Generate' });
@ -361,6 +386,11 @@ function combinePrefixes(str1, str2) {
return result;
}
function onExpandInput() {
extension_settings.sd.expand = !!$(this).prop('checked');
saveSettingsDebounced();
}
function onRefineModeInput() {
extension_settings.sd.refine_mode = !!$('#sd_refine_mode').prop('checked');
saveSettingsDebounced();
@ -1610,6 +1640,7 @@ jQuery(async () => {
$('#sd_novel_upscale_ratio').on('input', onNovelUpscaleRatioInput);
$('#sd_novel_anlas_guard').on('input', onNovelAnlasGuardInput);
$('#sd_novel_view_anlas').on('click', onViewAnlasClick);
$('#sd_expand').on('input', onExpandInput);
$('#sd_character_prompt_block').hide();
$('.sd_settings .inline-drawer-toggle').on('click', function () {

View File

@ -12,6 +12,10 @@
<input id="sd_refine_mode" type="checkbox" />
Edit prompts before generation
</label>
<label for="sd_expand" class="checkbox_label" title="Automatically extend prompts using text generation model">
<input id="sd_expand" type="checkbox" />
Auto-enhance prompts
</label>
<label for="sd_source">Source</label>
<select id="sd_source">
<option value="extras">Extras API (local / remote)</option>

View File

@ -1,6 +1,43 @@
const fetch = require('node-fetch').default;
const { getBasicAuthHeader, delay } = require('./util');
/**
* Sanitizes a string.
* @param {string} x String to sanitize
* @returns {string} Sanitized string
*/
function safeStr(x) {
x = String(x);
for (let i = 0; i < 16; i++) {
x = x.replace(/ /g, ' ');
}
x = x.trim();
x = x.replace(/^[\s,.]+|[\s,.]+$/g, '');
return x;
}
const splitStrings = [
', extremely',
', intricate,',
];
const dangerousPatterns = '[]【】()|:';
/**
* Removes patterns from a string.
* @param {string} x String to sanitize
* @param {string} pattern Pattern to remove
* @returns {string} Sanitized string
*/
function removePattern(x, pattern) {
for (let i = 0; i < pattern.length; i++) {
let p = pattern[i];
let regex = new RegExp("\\" + p, 'g');
x = x.replace(regex, '');
}
return x;
}
/**
* Registers the endpoints for the Stable Diffusion API extension.
* @param {import("express").Express} app Express app
@ -275,6 +312,40 @@ function registerEndpoints(app, jsonParser) {
return response.sendStatus(500);
}
});
/**
* SD prompt expansion using GPT-2 text generation model.
* Adapted from: https://github.com/lllyasviel/Fooocus/blob/main/modules/expansion.py
*/
app.post('/api/sd/expand', jsonParser, async (request, response) => {
const originalPrompt = request.body.prompt;
if (!originalPrompt) {
console.warn('No prompt provided for SD expansion.');
return response.send({ prompt: '' });
}
console.log('Refine prompt input:', originalPrompt);
const splitString = splitStrings[Math.floor(Math.random() * splitStrings.length)];
let prompt = safeStr(originalPrompt) + splitString;
try {
const task = 'text-generation';
const module = await import('./transformers.mjs');
const pipe = await module.default.getPipeline(task);
const result = await pipe(prompt, { num_beams: 1, max_new_tokens: 256, do_sample: true });
const newText = result[0].generated_text;
const newPrompt = safeStr(removePattern(newText, dangerousPatterns));
console.log('Refine prompt output:', newPrompt);
return response.send({ prompt: newPrompt });
} catch {
console.warn('Failed to load transformers.js pipeline.');
return response.send({ prompt: originalPrompt });
}
});
}
module.exports = {

View File

@ -1,4 +1,4 @@
import { pipeline, env, RawImage } from 'sillytavern-transformers';
import { pipeline, env, RawImage, Pipeline } from 'sillytavern-transformers';
import { getConfigValue } from './util.js';
import path from 'path';
import _ from 'lodash';
@ -28,8 +28,18 @@ const tasks = {
pipeline: null,
configField: 'extras.embeddingModel',
},
'text-generation': {
defaultModel: 'Cohee/fooocus_expansion-onnx',
pipeline: null,
configField: 'extras.promptExpansionModel',
},
}
/**
* Gets a RawImage object from a base64-encoded image.
* @param {string} image Base64-encoded image
* @returns {Promise<RawImage|null>} Object representing the image
*/
async function getRawImage(image) {
try {
const buffer = Buffer.from(image, 'base64');
@ -43,6 +53,11 @@ async function getRawImage(image) {
}
}
/**
* Gets the model to use for a given transformers.js task.
* @param {string} task The task to get the model for
* @returns {string} The model to use for the given task
*/
function getModelForTask(task) {
const defaultModel = tasks[task].defaultModel;
@ -55,6 +70,11 @@ function getModelForTask(task) {
}
}
/**
* Gets the transformers.js pipeline for a given task.
* @param {string} task The task to get the pipeline for
* @returns {Promise<Pipeline>} Pipeline for the task
*/
async function getPipeline(task) {
if (tasks[task].pipeline) {
return tasks[task].pipeline;