CFG: Use ST prompt builder for negatives

Make the generate function build a negative prompt in addition to the
normal one. This allows for nonconflicting insertion with other extension
prompts and World Info.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-08-20 00:44:39 -04:00
parent 7191f7a8ad
commit 92e6c6a998
4 changed files with 107 additions and 100 deletions

View File

@@ -164,6 +164,7 @@ import { deviceInfo } from "./scripts/RossAscends-mods.js";
import { registerPromptManagerMigration } from "./scripts/PromptManager.js"; import { registerPromptManagerMigration } from "./scripts/PromptManager.js";
import { getRegexedString, regex_placement } from "./scripts/extensions/regex/engine.js"; import { getRegexedString, regex_placement } from "./scripts/extensions/regex/engine.js";
import { FILTER_TYPES, FilterHelper } from "./scripts/filters.js"; import { FILTER_TYPES, FilterHelper } from "./scripts/filters.js";
import { getCfgPrompt, getGuidanceScale } from "./scripts/extensions/cfg/util.js";
//exporting functions and vars for mods //exporting functions and vars for mods
export { export {
@@ -2762,6 +2763,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
} }
const anchorDepth = Math.abs(i - arrMes.length + 1); const anchorDepth = Math.abs(i - arrMes.length + 1);
// NOTE: Depth injected here!
const extensionAnchor = getExtensionPrompt(extension_prompt_types.IN_CHAT, anchorDepth); const extensionAnchor = getExtensionPrompt(extension_prompt_types.IN_CHAT, anchorDepth);
if (anchorDepth > 0 && extensionAnchor && extensionAnchor.length) { if (anchorDepth > 0 && extensionAnchor && extensionAnchor.length) {
@@ -2773,7 +2775,6 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
} }
let mesExmString = ''; let mesExmString = '';
let mesSendString = '';
function setPromtString() { function setPromtString() {
if (main_api == 'openai') { if (main_api == 'openai') {
@@ -2782,65 +2783,57 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
console.debug('--setting Prompt string'); console.debug('--setting Prompt string');
mesExmString = pinExmString ?? mesExamplesArray.slice(0, count_exm_add).join(''); mesExmString = pinExmString ?? mesExamplesArray.slice(0, count_exm_add).join('');
mesSendString = ''; mesSend[mesSend.length - 1] = modifyLastPromptLine(mesSend[mesSend.length - 1]);
for (let j = 0; j < mesSend.length; j++) {
const isBottom = j === mesSend.length - 1;
mesSendString += mesSend[j];
if (isBottom) {
mesSendString = modifyLastPromptLine(mesSendString);
}
}
} }
function modifyLastPromptLine(mesSendString) { function modifyLastPromptLine(lastMesString) {
// Add quiet generation prompt at depth 0 // Add quiet generation prompt at depth 0
if (quiet_prompt && quiet_prompt.length) { if (quiet_prompt && quiet_prompt.length) {
const name = is_pygmalion ? 'You' : name1; const name = is_pygmalion ? 'You' : name1;
const quietAppend = isInstruct ? formatInstructModeChat(name, quiet_prompt, false, true, false, name1, name2) : `\n${name}: ${quiet_prompt}`; const quietAppend = isInstruct ? formatInstructModeChat(name, quiet_prompt, false, true, false, name1, name2) : `\n${name}: ${quiet_prompt}`;
mesSendString += quietAppend; lastMesString += quietAppend;
// Bail out early // Bail out early
return mesSendString; return lastMesString;
} }
// Get instruct mode line // Get instruct mode line
if (isInstruct && tokens_already_generated === 0) { if (isInstruct && tokens_already_generated === 0) {
const name = isImpersonate ? (is_pygmalion ? 'You' : name1) : name2; const name = isImpersonate ? (is_pygmalion ? 'You' : name1) : name2;
mesSendString += formatInstructModePrompt(name, isImpersonate, promptBias, name1, name2); lastMesString += formatInstructModePrompt(name, isImpersonate, promptBias, name1, name2);
} }
// Get non-instruct impersonation line // Get non-instruct impersonation line
if (!isInstruct && isImpersonate && tokens_already_generated === 0) { if (!isInstruct && isImpersonate && tokens_already_generated === 0) {
const name = is_pygmalion ? 'You' : name1; const name = is_pygmalion ? 'You' : name1;
if (!mesSendString.endsWith('\n')) { if (!lastMesString.endsWith('\n')) {
mesSendString += '\n'; lastMesString += '\n';
} }
mesSendString += name + ':'; lastMesString += name + ':';
} }
// Add character's name // Add character's name
if (!isInstruct && force_name2 && tokens_already_generated === 0) { if (!isInstruct && force_name2 && tokens_already_generated === 0) {
if (!mesSendString.endsWith('\n')) { if (!lastMesString.endsWith('\n')) {
mesSendString += '\n'; lastMesString += '\n';
} }
// Add a leading space to the prompt bias if applicable // Add a leading space to the prompt bias if applicable
if (!promptBias || promptBias.length === 0) { if (!promptBias || promptBias.length === 0) {
console.debug("No prompt bias was found."); console.debug("No prompt bias was found.");
mesSendString += `${name2}:`; lastMesString += `${name2}:`;
} else if (promptBias.startsWith(' ')) { } else if (promptBias.startsWith(' ')) {
console.debug(`A prompt bias with a leading space was found: ${promptBias}`); console.debug(`A prompt bias with a leading space was found: ${promptBias}`);
mesSendString += `${name2}:${promptBias}` lastMesString += `${name2}:${promptBias}`
} else { } else {
console.debug(`A prompt bias was found: ${promptBias}`); console.debug(`A prompt bias was found: ${promptBias}`);
mesSendString += `${name2}: ${promptBias}`; lastMesString += `${name2}: ${promptBias}`;
} }
} else if (power_user.user_prompt_bias && !isImpersonate && !isInstruct) { } else if (power_user.user_prompt_bias && !isImpersonate && !isInstruct) {
console.debug(`A prompt bias was found without character's name appended: ${promptBias}`); console.debug(`A prompt bias was found without character's name appended: ${promptBias}`);
mesSendString += substituteParams(power_user.user_prompt_bias); lastMesString += substituteParams(power_user.user_prompt_bias);
} }
return mesSendString; return lastMesString;
} }
function checkPromtSize() { function checkPromtSize() {
@@ -2849,7 +2842,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
const prompt = [ const prompt = [
storyString, storyString,
mesExmString, mesExmString,
mesSendString, mesSend.join(''),
generatedPromtCache, generatedPromtCache,
allAnchors, allAnchors,
quiet_prompt, quiet_prompt,
@@ -2878,30 +2871,60 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
setPromtString(); setPromtString();
} }
const cfgGuidanceScale = getGuidanceScale();
function getCombinedPrompt(isNegative) {
if (isNegative && cfgGuidanceScale !== 1) {
const negativePrompt = getCfgPrompt(cfgGuidanceScale);
if (negativePrompt && negativePrompt?.value) {
// TODO: kingbri: use the insertion depth method instead of splicing
mesSend.splice(mesSend.length - negativePrompt.depth, 0, `${negativePrompt.value}\n`);
}
}
let mesSendString = mesSend.join('');
// add chat preamble // add chat preamble
mesSendString = addChatsPreamble(mesSendString); mesSendString = addChatsPreamble(mesSendString);
// add a custom dingus (if defined) // add a custom dingus (if defined)
mesSendString = addChatsSeparator(mesSendString); mesSendString = addChatsSeparator(mesSendString);
let finalPromt = if (zeroDepthAnchor && zeroDepthAnchor.length) {
if (!isMultigenEnabled() || tokens_already_generated == 0) {
combinedPrompt = appendZeroDepthAnchor(force_name2, zeroDepthAnchor, combinedPrompt);
}
}
let combinedPrompt =
storyString + storyString +
afterScenarioAnchor + afterScenarioAnchor +
mesExmString + mesExmString +
mesSendString + mesSendString +
generatedPromtCache; generatedPromtCache;
if (zeroDepthAnchor && zeroDepthAnchor.length) { combinedPrompt = combinedPrompt.replace(/\r/gm, '');
if (!isMultigenEnabled() || tokens_already_generated == 0) {
finalPromt = appendZeroDepthAnchor(force_name2, zeroDepthAnchor, finalPromt);
}
}
finalPromt = finalPromt.replace(/\r/gm, '');
if (power_user.collapse_newlines) { if (power_user.collapse_newlines) {
finalPromt = collapseNewlines(finalPromt); combinedPrompt = collapseNewlines(combinedPrompt);
} }
return combinedPrompt;
}
let mesSendString = mesSend.join('');
// add chat preamble
mesSendString = addChatsPreamble(mesSendString);
// add a custom dingus (if defined)
mesSendString = addChatsSeparator(mesSendString);
let finalPromt = getCombinedPrompt(false);
let negativePrompt = getCombinedPrompt(true);
const cfgValues = {
guidanceScale: cfgGuidanceScale?.value,
negativePrompt: negativePrompt
};
let this_amount_gen = parseInt(amount_gen); // how many tokens the AI will be requested to generate let this_amount_gen = parseInt(amount_gen); // how many tokens the AI will be requested to generate
let this_settings = koboldai_settings[koboldai_setting_names[preset_settings]]; let this_settings = koboldai_settings[koboldai_setting_names[preset_settings]];
@@ -2935,12 +2958,12 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
} }
} }
else if (main_api == 'textgenerationwebui') { else if (main_api == 'textgenerationwebui') {
generate_data = getTextGenGenerationData(finalPromt, this_amount_gen, isImpersonate); generate_data = getTextGenGenerationData(finalPromt, this_amount_gen, isImpersonate, cfgValues);
generate_data.use_mancer = api_use_mancer_webui; generate_data.use_mancer = api_use_mancer_webui;
} }
else if (main_api == 'novel') { else if (main_api == 'novel') {
const this_settings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]]; const this_settings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]];
generate_data = getNovelGenerationData(finalPromt, this_settings, this_amount_gen, isImpersonate); generate_data = getNovelGenerationData(finalPromt, this_settings, this_amount_gen, isImpersonate, cfgValues);
} }
else if (main_api == 'openai') { else if (main_api == 'openai') {
let [prompt, counts] = prepareOpenAIMessages({ let [prompt, counts] = prepareOpenAIMessages({

View File

@@ -17,50 +17,13 @@ export const metadataKeys = {
negative_separator: "cfg_negative_separator" negative_separator: "cfg_negative_separator"
} }
// Gets the CFG value from hierarchy of chat -> character -> global // Gets the CFG guidance scale
// Returns undefined values which should be handled in the respective backend APIs
// TODO: Maybe use existing prompt building/substitution?
// TODO: Insertion depth conflicts with author's note. Shouldn't matter though since CFG is prompt mixing.
export function getCfg(prompt) {
const splitPrompt = prompt?.split("\n") ?? [];
let splitNegativePrompt = [];
const charaCfg = extension_settings.cfg.chara?.find((e) => e.name === getCharaFilename(this_chid));
const guidanceScale = getGuidanceScale(charaCfg);
const chatNegativeCombine = chat_metadata[metadataKeys.negative_combine] ?? [];
// If there's a guidance scale, continue. Otherwise assume undefined
if (guidanceScale?.value && guidanceScale?.value !== 1) {
if (guidanceScale.type === cfgType.chat || chatNegativeCombine.includes(cfgType.chat)) {
splitNegativePrompt.unshift(substituteParams(chat_metadata[metadataKeys.negative_prompt])?.trim());
}
if (guidanceScale.type === cfgType.chara || chatNegativeCombine.includes(cfgType.chara)) {
splitNegativePrompt.unshift(substituteParams(charaCfg.negative_prompt)?.trim())
}
if (guidanceScale.type === cfgType.global || chatNegativeCombine.includes(cfgType.global)) {
splitNegativePrompt.unshift(substituteParams(extension_settings.cfg.global.negative_prompt)?.trim());
}
// This line is a bit hacky with a JSON.stringify and JSON.parse. Fix this if possible.
const negativeSeparator = JSON.parse(chat_metadata[metadataKeys.negative_separator] || JSON.stringify("\n")) ?? "\n";
const combinedNegatives = splitNegativePrompt.filter((e) => e.length > 0).join(negativeSeparator);
const insertionDepth = chat_metadata[metadataKeys.negative_insertion_depth] ?? 1;
console.log(insertionDepth)
splitPrompt.splice(splitPrompt.length - insertionDepth, 0, combinedNegatives);
console.log(`Setting CFG with guidance scale: ${guidanceScale.value}, negatives: ${combinedNegatives}`);
return {
guidanceScale: guidanceScale.value,
negativePrompt: splitPrompt.join("\n")
}
}
}
// If the guidance scale is 1, ignore the CFG negative prompt since it won't be used anyways // If the guidance scale is 1, ignore the CFG negative prompt since it won't be used anyways
function getGuidanceScale(charaCfg) { export function getGuidanceScale() {
const charaCfg = extension_settings.cfg.chara?.find((e) => e.name === getCharaFilename(this_chid));
const chatGuidanceScale = chat_metadata[metadataKeys.guidance_scale]; const chatGuidanceScale = chat_metadata[metadataKeys.guidance_scale];
const groupchatCharOverride = chat_metadata[metadataKeys.groupchat_individual_chars] ?? false; const groupchatCharOverride = chat_metadata[metadataKeys.groupchat_individual_chars] ?? false;
if (chatGuidanceScale && chatGuidanceScale !== 1 && !groupchatCharOverride) { if (chatGuidanceScale && chatGuidanceScale !== 1 && !groupchatCharOverride) {
return { return {
type: cfgType.chat, type: cfgType.chat,
@@ -80,3 +43,33 @@ function getGuidanceScale(charaCfg) {
value: extension_settings.cfg.global.guidance_scale value: extension_settings.cfg.global.guidance_scale
}; };
} }
// Gets the CFG prompt. Currently only gets the negative prompt
export function getCfgPrompt(guidanceScale) {
let splitNegativePrompt = [];
const chatNegativeCombine = chat_metadata[metadataKeys.negative_combine] ?? [];
if (guidanceScale.type === cfgType.chat || chatNegativeCombine.includes(cfgType.chat)) {
splitNegativePrompt.unshift(substituteParams(chat_metadata[metadataKeys.negative_prompt])?.trim());
}
const charaCfg = extension_settings.cfg.chara?.find((e) => e.name === getCharaFilename(this_chid));
if (guidanceScale.type === cfgType.chara || chatNegativeCombine.includes(cfgType.chara)) {
splitNegativePrompt.unshift(substituteParams(charaCfg.negative_prompt)?.trim())
}
if (guidanceScale.type === cfgType.global || chatNegativeCombine.includes(cfgType.global)) {
splitNegativePrompt.unshift(substituteParams(extension_settings.cfg.global.negative_prompt)?.trim());
}
// This line is a bit hacky with a JSON.stringify and JSON.parse. Fix this if possible.
const negativeSeparator = JSON.parse(chat_metadata[metadataKeys.negative_separator] || JSON.stringify("\n")) ?? "\n";
const combinedNegatives = splitNegativePrompt.filter((e) => e.length > 0).join(negativeSeparator);
const insertionDepth = chat_metadata[metadataKeys.negative_insertion_depth] ?? 1;
console.log(`Setting CFG with guidance scale: ${guidanceScale.value}, negatives: ${combinedNegatives}`);
return {
value: combinedNegatives,
depth: insertionDepth
};
}

View File

@@ -7,7 +7,6 @@ import {
saveSettingsDebounced, saveSettingsDebounced,
setGenerationParamsFromPreset setGenerationParamsFromPreset
} from "../script.js"; } from "../script.js";
import { getCfg } from "./extensions/cfg/util.js";
import { MAX_CONTEXT_DEFAULT, tokenizers } from "./power-user.js"; import { MAX_CONTEXT_DEFAULT, tokenizers } from "./power-user.js";
import { import {
getSortableDelay, getSortableDelay,
@@ -395,7 +394,7 @@ function getBadWordPermutations(text) {
return result; return result;
} }
export function getNovelGenerationData(finalPrompt, this_settings, this_amount_gen, isImpersonate) { export function getNovelGenerationData(finalPrompt, this_settings, this_amount_gen, isImpersonate, cfgValues) {
const clio = nai_settings.model_novel.includes('clio'); const clio = nai_settings.model_novel.includes('clio');
const kayra = nai_settings.model_novel.includes('kayra'); const kayra = nai_settings.model_novel.includes('kayra');
@@ -410,7 +409,6 @@ export function getNovelGenerationData(finalPrompt, this_settings, this_amount_g
: undefined; : undefined;
const prefix = selectPrefix(nai_settings.prefix, finalPrompt); const prefix = selectPrefix(nai_settings.prefix, finalPrompt);
const cfgSettings = getCfg(finalPrompt);
let logitBias = []; let logitBias = [];
if (tokenizerType !== tokenizers.NONE && Array.isArray(nai_settings.logit_bias) && nai_settings.logit_bias.length) { if (tokenizerType !== tokenizers.NONE && Array.isArray(nai_settings.logit_bias) && nai_settings.logit_bias.length) {
@@ -437,8 +435,8 @@ export function getNovelGenerationData(finalPrompt, this_settings, this_amount_g
"typical_p": parseFloat(nai_settings.typical_p), "typical_p": parseFloat(nai_settings.typical_p),
"mirostat_lr": parseFloat(nai_settings.mirostat_lr), "mirostat_lr": parseFloat(nai_settings.mirostat_lr),
"mirostat_tau": parseFloat(nai_settings.mirostat_tau), "mirostat_tau": parseFloat(nai_settings.mirostat_tau),
"cfg_scale": cfgSettings?.guidanceScale ?? parseFloat(nai_settings.cfg_scale), "cfg_scale": cfgValues?.guidanceScale ?? parseFloat(nai_settings.cfg_scale),
"cfg_uc": cfgSettings?.negativePrompt ?? nai_settings.cfg_uc ?? "", "cfg_uc": cfgValues?.negativePrompt ?? nai_settings.cfg_uc ?? "",
"phrase_rep_pen": nai_settings.phrase_rep_pen, "phrase_rep_pen": nai_settings.phrase_rep_pen,
"stop_sequences": stopSequences, "stop_sequences": stopSequences,
"bad_words_ids": badWordIds, "bad_words_ids": badWordIds,

View File

@@ -6,8 +6,6 @@ import {
setGenerationParamsFromPreset, setGenerationParamsFromPreset,
} from "../script.js"; } from "../script.js";
import { getCfg } from "./extensions/cfg/util.js";
import { import {
power_user, power_user,
} from "./power-user.js"; } from "./power-user.js";
@@ -235,12 +233,7 @@ async function generateTextGenWithStreaming(generate_data, signal) {
} }
} }
export function getTextGenGenerationData(finalPromt, this_amount_gen, isImpersonate) { export function getTextGenGenerationData(finalPromt, this_amount_gen, isImpersonate, cfgValues) {
let cfgValues = {};
if (!isImpersonate) {
cfgValues = getCfg(finalPromt);
}
return { return {
'prompt': finalPromt, 'prompt': finalPromt,
'max_new_tokens': this_amount_gen, 'max_new_tokens': this_amount_gen,
@@ -258,8 +251,8 @@ export function getTextGenGenerationData(finalPromt, this_amount_gen, isImperson
'penalty_alpha': textgenerationwebui_settings.penalty_alpha, 'penalty_alpha': textgenerationwebui_settings.penalty_alpha,
'length_penalty': textgenerationwebui_settings.length_penalty, 'length_penalty': textgenerationwebui_settings.length_penalty,
'early_stopping': textgenerationwebui_settings.early_stopping, 'early_stopping': textgenerationwebui_settings.early_stopping,
'guidance_scale': cfgValues?.guidanceScale ?? textgenerationwebui_settings.guidance_scale ?? 1, 'guidance_scale': isImpersonate ? 1 : cfgValues?.guidanceScale ?? textgenerationwebui_settings.guidance_scale ?? 1,
'negative_prompt': cfgValues?.negativePrompt ?? textgenerationwebui_settings.negative_prompt ?? '', 'negative_prompt': isImpersonate ? '' : cfgValues?.negativePrompt ?? textgenerationwebui_settings.negative_prompt ?? '',
'seed': textgenerationwebui_settings.seed, 'seed': textgenerationwebui_settings.seed,
'add_bos_token': textgenerationwebui_settings.add_bos_token, 'add_bos_token': textgenerationwebui_settings.add_bos_token,
'stopping_strings': getStoppingStrings(isImpersonate, false), 'stopping_strings': getStoppingStrings(isImpersonate, false),