From cdbca6d9fde09bc3ad429a99130cfe1cf887b3ec Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 17 Aug 2023 23:51:17 -0400 Subject: [PATCH] CFG: Include the entire prompt with negative prompt CFG with LLMs works differently than stable diffusion. The main principle is prompt mixing and utilizing the differences between the two prompts rather than a full "negative prompt" of what the user doesn't want. SillyTavern its own way of formatting a prompt sent to an LLM backend. Therefore, take that prompt and add negatives to it. Signed-off-by: kingbri --- public/script.js | 6 +++++- public/scripts/extensions/cfg/index.js | 1 + public/scripts/extensions/cfg/util.js | 27 +++++++++++++++++++++----- public/scripts/textgen-settings.js | 4 +--- 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/public/script.js b/public/script.js index 97b22af2c..430deda8f 100644 --- a/public/script.js +++ b/public/script.js @@ -164,6 +164,7 @@ import { deviceInfo } from "./scripts/RossAscends-mods.js"; import { registerPromptManagerMigration } from "./scripts/PromptManager.js"; import { getRegexedString, regex_placement } from "./scripts/extensions/regex/engine.js"; import { FILTER_TYPES, FilterHelper } from "./scripts/filters.js"; +import { getCfg, getNegativePrompt } from "./scripts/extensions/cfg/util.js"; //exporting functions and vars for mods export { @@ -2905,6 +2906,8 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject, let this_amount_gen = parseInt(amount_gen); // how many tokens the AI will be requested to generate let this_settings = koboldai_settings[koboldai_setting_names[preset_settings]]; + const cfgValues = getCfg(finalPromt); + if (isMultigenEnabled() && type !== 'quiet') { // if nothing has been generated yet.. this_amount_gen = getMultigenAmount(); @@ -2912,6 +2915,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject, let thisPromptBits = []; + // TODO: Make this a switch if (main_api == 'koboldhorde' && horde_settings.auto_adjust_response_length) { this_amount_gen = Math.min(this_amount_gen, adjustedParams.maxLength); this_amount_gen = Math.max(this_amount_gen, MIN_AMOUNT_GEN); // prevent validation errors @@ -2934,7 +2938,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject, } } else if (main_api == 'textgenerationwebui') { - generate_data = getTextGenGenerationData(finalPromt, this_amount_gen, isImpersonate); + generate_data = getTextGenGenerationData(finalPromt, this_amount_gen, isImpersonate, cfgValues); generate_data.use_mancer = api_use_mancer_webui; } else if (main_api == 'novel') { diff --git a/public/scripts/extensions/cfg/index.js b/public/scripts/extensions/cfg/index.js index 31f2af7dc..a5b1f5773 100644 --- a/public/scripts/extensions/cfg/index.js +++ b/public/scripts/extensions/cfg/index.js @@ -273,6 +273,7 @@ jQuery(async () => { saveSettingsDebounced(); }); + // TODO: Add negative insertion depth windowHtml.find('#global_cfg_negative_prompt').on('input', function() { extension_settings.cfg.global.negative_prompt = $(this).val(); saveSettingsDebounced(); diff --git a/public/scripts/extensions/cfg/util.js b/public/scripts/extensions/cfg/util.js index 8637ced3a..0f6cdfb75 100644 --- a/public/scripts/extensions/cfg/util.js +++ b/public/scripts/extensions/cfg/util.js @@ -12,18 +12,23 @@ export const metadataKeys = { guidance_scale: "cfg_guidance_scale", negative_prompt: "cfg_negative_prompt", negative_combine: "cfg_negative_combine", - groupchat_individual_chars: "cfg_groupchat_individual_chars" + groupchat_individual_chars: "cfg_groupchat_individual_chars", + negative_insertion_depth: "cfg_negative_insertion_depth" } // Gets the CFG value from hierarchy of chat -> character -> global // Returns undefined values which should be handled in the respective backend APIs -export function getCfg() { +// TODO: Include a custom negative separator +// TODO: Maybe use existing prompt building/substitution? +export function getCfg(prompt) { + const splitPrompt = prompt?.split("\n") ?? []; let splitNegativePrompt = []; const charaCfg = extension_settings.cfg.chara?.find((e) => e.name === getCharaFilename(this_chid)); const guidanceScale = getGuidanceScale(charaCfg); const chatNegativeCombine = chat_metadata[metadataKeys.negative_combine] ?? []; // If there's a guidance scale, continue. Otherwise assume undefined + // TODO: Run substitute params if (guidanceScale?.value && guidanceScale?.value !== 1) { if (guidanceScale.type === cfgType.chat || chatNegativeCombine.includes(cfgType.chat)) { splitNegativePrompt.push(chat_metadata[metadataKeys.negative_prompt]?.trim()); @@ -37,12 +42,15 @@ export function getCfg() { splitNegativePrompt.push(extension_settings.cfg.global.negative_prompt?.trim()); } - const combinedNegatives = splitNegativePrompt.filter((e) => e.length > 0).join(", "); - console.debug(`Setting CFG with guidance scale: ${guidanceScale.value}, negatives: ${combinedNegatives}`) + // TODO: use a custom separator for join + const combinedNegatives = splitNegativePrompt.filter((e) => e.length > 0).join("\n"); + const insertionDepth = chat_metadata[metadataKeys.negative_insertion_depth] ?? 1; + splitPrompt.splice(splitPrompt.length - insertionDepth, 0, combinedNegatives); + console.log(`Setting CFG with guidance scale: ${guidanceScale.value}, negatives: ${combinedNegatives}`); return { guidanceScale: guidanceScale.value, - negativePrompt: combinedNegatives + negativePrompt: splitPrompt.join("\n") } } } @@ -70,3 +78,12 @@ function getGuidanceScale(charaCfg) { value: extension_settings.cfg.global.guidance_scale }; } + +export function getNegativePrompt(prompt) { + const splitPrompt = prompt.split("\n"); + const insertionDepth = chat_metadata[metadataKeys.negative_insertion_depth] ?? 1; + splitPrompt.splice(splitPrompt.length - insertionDepth, 0, "Test negative list"); + console.log(splitPrompt); + const negativePrompt = splitPrompt.join("\n"); + //console.log(negativePrompt); +} diff --git a/public/scripts/textgen-settings.js b/public/scripts/textgen-settings.js index 9ac736306..48a7c162e 100644 --- a/public/scripts/textgen-settings.js +++ b/public/scripts/textgen-settings.js @@ -235,9 +235,7 @@ async function generateTextGenWithStreaming(generate_data, signal) { } } -export function getTextGenGenerationData(finalPromt, this_amount_gen, isImpersonate) { - const cfgValues = getCfg(); - +export function getTextGenGenerationData(finalPromt, this_amount_gen, isImpersonate, cfgValues) { return { 'prompt': finalPromt, 'max_new_tokens': this_amount_gen,