mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-03-30 03:40:16 +02:00
Add SD prompt expansion
This commit is contained in:
parent
5c6343e85e
commit
c4e6b565a5
@ -26,6 +26,8 @@ const extras = {
|
||||
captioningModel: 'Xenova/vit-gpt2-image-captioning',
|
||||
// Feature extraction model. HuggingFace ID of a model in ONNX format.
|
||||
embeddingModel: 'Xenova/all-mpnet-base-v2',
|
||||
// GPT-2 text generation model. HuggingFace ID of a model in ONNX format.
|
||||
promptExpansionModel: 'Cohee/fooocus_expansion-onnx',
|
||||
};
|
||||
|
||||
// Request overrides for additional headers
|
||||
|
@ -160,6 +160,7 @@ const defaultSettings = {
|
||||
|
||||
// Refine mode
|
||||
refine_mode: false,
|
||||
expand: false,
|
||||
|
||||
prompts: promptTemplates,
|
||||
|
||||
@ -257,6 +258,7 @@ async function loadSettings() {
|
||||
$('#sd_restore_faces').prop('checked', extension_settings.sd.restore_faces);
|
||||
$('#sd_enable_hr').prop('checked', extension_settings.sd.enable_hr);
|
||||
$('#sd_refine_mode').prop('checked', extension_settings.sd.refine_mode);
|
||||
$('#sd_expand').prop('checked', extension_settings.sd.expand);
|
||||
$('#sd_auto_url').val(extension_settings.sd.auto_url);
|
||||
$('#sd_auto_auth').val(extension_settings.sd.auto_auth);
|
||||
$('#sd_vlad_url').val(extension_settings.sd.vlad_url);
|
||||
@ -300,7 +302,30 @@ function addPromptTemplates() {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Expands an image generation prompt by POSTing it to the `/api/sd/expand` endpoint.
 *
 * This is a best-effort step: it never throws. If the request fails, the server
 * answers with a non-OK status, or the response carries no usable `prompt` string,
 * the original prompt is returned unchanged so image generation can continue.
 * @param {string} prompt Prompt to expand
 * @returns {Promise<string>} Expanded prompt, or the original prompt on any failure
 */
async function expandPrompt(prompt) {
    try {
        const response = await fetch('/api/sd/expand', {
            method: 'POST',
            headers: getRequestHeaders(),
            body: JSON.stringify({ prompt }),
        });

        if (!response.ok) {
            throw new Error('API returned an error.');
        }

        const data = await response.json();
        const expanded = data && data.prompt;

        // The caller goes on to use the result as a string (e.g. `prompt.trim()`),
        // so never hand back undefined or an empty expansion.
        if (typeof expanded !== 'string' || expanded.length === 0) {
            return prompt;
        }

        return expanded;
    } catch (error) {
        // Expansion is optional; report the failure instead of hiding it and fall back.
        console.warn('SD prompt expansion failed. Using the original prompt.', error);
        return prompt;
    }
}
|
||||
|
||||
async function refinePrompt(prompt) {
|
||||
if (extension_settings.sd.expand) {
|
||||
prompt = await expandPrompt(prompt);
|
||||
}
|
||||
|
||||
if (extension_settings.sd.refine_mode) {
|
||||
const refinedPrompt = await callPopup('<h3>Review and edit the prompt:</h3>Press "Cancel" to abort the image generation.', 'input', prompt.trim(), { rows: 5, okButton: 'Generate' });
|
||||
|
||||
@ -361,6 +386,11 @@ function combinePrefixes(str1, str2) {
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
 * Handler for the `input` event of the `#sd_expand` checkbox.
 * Stores the checkbox state as a boolean in `extension_settings.sd.expand`
 * and schedules a settings save.
 */
function onExpandInput() {
    const isChecked = Boolean($(this).prop('checked'));
    extension_settings.sd.expand = isChecked;
    saveSettingsDebounced();
}
|
||||
|
||||
function onRefineModeInput() {
|
||||
extension_settings.sd.refine_mode = !!$('#sd_refine_mode').prop('checked');
|
||||
saveSettingsDebounced();
|
||||
@ -1610,6 +1640,7 @@ jQuery(async () => {
|
||||
$('#sd_novel_upscale_ratio').on('input', onNovelUpscaleRatioInput);
|
||||
$('#sd_novel_anlas_guard').on('input', onNovelAnlasGuardInput);
|
||||
$('#sd_novel_view_anlas').on('click', onViewAnlasClick);
|
||||
$('#sd_expand').on('input', onExpandInput);
|
||||
$('#sd_character_prompt_block').hide();
|
||||
|
||||
$('.sd_settings .inline-drawer-toggle').on('click', function () {
|
||||
|
@ -12,6 +12,10 @@
|
||||
<input id="sd_refine_mode" type="checkbox" />
|
||||
Edit prompts before generation
|
||||
</label>
|
||||
<label for="sd_expand" class="checkbox_label" title="Automatically extend prompts using text generation model">
|
||||
<input id="sd_expand" type="checkbox" />
|
||||
Auto-enhance prompts
|
||||
</label>
|
||||
<label for="sd_source">Source</label>
|
||||
<select id="sd_source">
|
||||
<option value="extras">Extras API (local / remote)</option>
|
||||
|
@ -1,6 +1,43 @@
|
||||
const fetch = require('node-fetch').default;
|
||||
const { getBasicAuthHeader, delay } = require('./util');
|
||||
|
||||
/**
 * Sanitizes a string.
 *
 * Coerces the input to a string, collapses every run of consecutive spaces into a
 * single space, and strips whitespace, commas and periods from both ends.
 * @param {string} x String to sanitize
 * @returns {string} Sanitized string
 */
function safeStr(x) {
    const text = String(x);

    // Collapse runs of spaces in a single pass. A fixed number of pairwise
    // replacements is unnecessary when the regex matches the whole run.
    const collapsed = text.replace(/ {2,}/g, ' ');

    // `\s` already covers everything `trim()` removes, so one edge-strip is enough.
    return collapsed.replace(/^[\s,.]+|[\s,.]+$/g, '');
}
|
||||
|
||||
// Suffixes appended to the sanitized prompt before text generation.
// One is picked at random per request to seed the continuation.
const splitStrings = [
    ', extremely',
    ', intricate,',
];

// Characters stripped from the generated text, one character at a time.
// Both the ASCII and the full-width forms of parentheses and colon are listed;
// listing the ASCII form twice would have no effect.
// NOTE(review): full-width forms restored to match the Fooocus reference — confirm against upstream.
const dangerousPatterns = '[]【】()（）|:：';
|
||||
|
||||
/**
 * Removes patterns from a string.
 *
 * Each character of `pattern` is treated as a literal and every occurrence of it
 * is deleted from `x`. Characters are never interpreted as regular expression
 * syntax, so letters such as `d` or `w` remove only themselves.
 * @param {string} x String to sanitize
 * @param {string} pattern Pattern to remove (a set of characters, not a regex)
 * @returns {string} Sanitized string
 */
function removePattern(x, pattern) {
    let result = x;

    // Iterate by code point so characters outside the BMP are not split in half.
    for (const ch of pattern) {
        // split/join removes literal occurrences without building a RegExp.
        result = result.split(ch).join('');
    }

    return result;
}
|
||||
|
||||
/**
|
||||
* Registers the endpoints for the Stable Diffusion API extension.
|
||||
* @param {import("express").Express} app Express app
|
||||
@ -275,6 +312,40 @@ function registerEndpoints(app, jsonParser) {
|
||||
return response.sendStatus(500);
|
||||
}
|
||||
});
|
||||
|
||||
/**
 * SD prompt expansion using GPT-2 text generation model.
 * Adapted from: https://github.com/lllyasviel/Fooocus/blob/main/modules/expansion.py
 *
 * Reads `prompt` from the JSON request body and always responds with `{ prompt }`:
 * - an empty string when no prompt was provided;
 * - the sanitized generated text on success;
 * - the original prompt when the pipeline cannot be loaded or generation fails.
 */
app.post('/api/sd/expand', jsonParser, async (request, response) => {
    const originalPrompt = request.body.prompt;

    if (!originalPrompt) {
        console.warn('No prompt provided for SD expansion.');
        return response.send({ prompt: '' });
    }

    console.log('Refine prompt input:', originalPrompt);

    // Append a randomly chosen suffix so the model has something to continue from.
    const splitString = splitStrings[Math.floor(Math.random() * splitStrings.length)];
    const prompt = safeStr(originalPrompt) + splitString;

    try {
        const task = 'text-generation';
        // Named `transformers` rather than `module` to avoid shadowing the CommonJS `module` object.
        const transformers = await import('./transformers.mjs');
        const pipe = await transformers.default.getPipeline(task);

        const result = await pipe(prompt, { num_beams: 1, max_new_tokens: 256, do_sample: true });

        const newText = result[0].generated_text;
        const newPrompt = safeStr(removePattern(newText, dangerousPatterns));
        console.log('Refine prompt output:', newPrompt);

        return response.send({ prompt: newPrompt });
    } catch (error) {
        // Covers both pipeline loading and generation failures; keep the cause visible in the log.
        console.warn('SD prompt expansion failed. Returning the original prompt.', error);
        return response.send({ prompt: originalPrompt });
    }
});
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { pipeline, env, RawImage } from 'sillytavern-transformers';
|
||||
import { pipeline, env, RawImage, Pipeline } from 'sillytavern-transformers';
|
||||
import { getConfigValue } from './util.js';
|
||||
import path from 'path';
|
||||
import _ from 'lodash';
|
||||
@ -28,8 +28,18 @@ const tasks = {
|
||||
pipeline: null,
|
||||
configField: 'extras.embeddingModel',
|
||||
},
|
||||
'text-generation': {
|
||||
defaultModel: 'Cohee/fooocus_expansion-onnx',
|
||||
pipeline: null,
|
||||
configField: 'extras.promptExpansionModel',
|
||||
},
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a RawImage object from a base64-encoded image.
|
||||
* @param {string} image Base64-encoded image
|
||||
* @returns {Promise<RawImage|null>} Object representing the image
|
||||
*/
|
||||
async function getRawImage(image) {
|
||||
try {
|
||||
const buffer = Buffer.from(image, 'base64');
|
||||
@ -43,6 +53,11 @@ async function getRawImage(image) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the model to use for a given transformers.js task.
|
||||
* @param {string} task The task to get the model for
|
||||
* @returns {string} The model to use for the given task
|
||||
*/
|
||||
function getModelForTask(task) {
|
||||
const defaultModel = tasks[task].defaultModel;
|
||||
|
||||
@ -55,6 +70,11 @@ function getModelForTask(task) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the transformers.js pipeline for a given task.
|
||||
* @param {string} task The task to get the pipeline for
|
||||
* @returns {Promise<Pipeline>} Pipeline for the task
|
||||
*/
|
||||
async function getPipeline(task) {
|
||||
if (tasks[task].pipeline) {
|
||||
return tasks[task].pipeline;
|
||||
|
Loading…
x
Reference in New Issue
Block a user