Add prompt injection filters

This commit is contained in:
Cohee
2024-12-06 19:53:02 +02:00
parent 9c43999e4b
commit d6f34f7b2c
3 changed files with 86 additions and 40 deletions

View File

@ -2875,23 +2875,54 @@ function addPersonaDescriptionExtensionPrompt() {
} }
} }
/**
 * Joins the values of every registered extension prompt into one string.
 * Prompts with an empty (after trimming) value are skipped, as are prompts
 * whose `filter` function resolves to a falsy value.
 * @returns {Promise<string>} Newline-joined prompt values passed through substituteParams
 */
async function getAllExtensionPrompts() {
    const collected = [];

    for (const entry of Object.values(extension_prompts)) {
        const text = entry?.value?.trim();

        if (text) {
            // Filters are optional; anything that is not a function counts as "no filter".
            const isAllowed = typeof entry.filter !== 'function' || await entry.filter();

            if (isAllowed) {
                collected.push(text);
            }
        }
    }

    const combined = collected.join('\n');
    return substituteParams(combined);
}
/**
 * Fetches a single extension prompt by the module name it is stored under.
 * @param {string} moduleName Module name
 * @returns {Promise<string>} The prompt value passed through substituteParams, or an empty
 * string when the name is empty, no prompt is stored under it, or the prompt's filter
 * resolves to a falsy value
 */
export async function getExtensionPromptByName(moduleName) {
    const entry = moduleName ? extension_prompts[moduleName] : undefined;

    if (!entry) {
        return '';
    }

    // The filter is only consulted when one was actually provided as a function.
    const isRejected = typeof entry.filter === 'function' && !await entry.filter();

    return isRejected ? '' : substituteParams(entry.value);
}
/** /**
@ -2902,27 +2933,36 @@ export function getExtensionPromptByName(moduleName) {
* @param {string} [separator] Separator for joining multiple prompts * @param {string} [separator] Separator for joining multiple prompts
* @param {number} [role] Role of the prompt * @param {number} [role] Role of the prompt
* @param {boolean} [wrap] Wrap start and end with a separator * @param {boolean} [wrap] Wrap start and end with a separator
* @returns {string} Extension prompt * @returns {Promise<string>} Extension prompt
*/ */
export async function getExtensionPrompt(position = extension_prompt_types.IN_PROMPT, depth = undefined, separator = '\n', role = undefined, wrap = true) {
    // Apply the cheap synchronous criteria first so that filter callbacks are only
    // invoked for prompts that would otherwise be injected. Keys are sorted to keep
    // the output order stable.
    const candidates = Object.keys(extension_prompts)
        .sort()
        .map((x) => extension_prompts[x])
        .filter(x => x.position == position && x.value)
        .filter(x => depth === undefined || x.depth === undefined || x.depth === depth)
        .filter(x => role === undefined || x.role === undefined || x.role === role);

    // Array.prototype.filter does not await its callback: an async predicate hands it
    // a Promise, which is always truthy, so nothing would ever be excluded. Resolve
    // every verdict up front and filter on the settled booleans instead.
    const verdicts = await Promise.all(candidates.map(async (prompt) => {
        const hasFilter = typeof prompt.filter === 'function';
        return !hasFilter || Boolean(await prompt.filter());
    }));

    let values = candidates
        .filter((_, index) => verdicts[index])
        .map(x => x.value.trim())
        .join(separator);

    // Wrapping only applies to non-empty output and never doubles an existing separator.
    if (wrap && values.length && !values.startsWith(separator)) {
        values = separator + values;
    }

    if (wrap && values.length && !values.endsWith(separator)) {
        values = values + separator;
    }

    if (values.length) {
        values = substituteParams(values);
    }

    return values;
}
export function baseChatReplace(value, name1, name2) { export function baseChatReplace(value, name1, name2) {
@ -3836,7 +3876,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
// Inject all Depth prompts. Chat Completion does it separately // Inject all Depth prompts. Chat Completion does it separately
let injectedIndices = []; let injectedIndices = [];
if (main_api !== 'openai') { if (main_api !== 'openai') {
injectedIndices = doChatInject(coreChat, isContinue); injectedIndices = await doChatInject(coreChat, isContinue);
} }
// Insert character jailbreak as the last user message (if exists, allowed, preferred, and not using Chat Completion) // Insert character jailbreak as the last user message (if exists, allowed, preferred, and not using Chat Completion)
@ -3909,8 +3949,8 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
} }
// Call combined AN into Generate // Call combined AN into Generate
const beforeScenarioAnchor = getExtensionPrompt(extension_prompt_types.BEFORE_PROMPT).trimStart(); const beforeScenarioAnchor = (await getExtensionPrompt(extension_prompt_types.BEFORE_PROMPT)).trimStart();
const afterScenarioAnchor = getExtensionPrompt(extension_prompt_types.IN_PROMPT); const afterScenarioAnchor = await getExtensionPrompt(extension_prompt_types.IN_PROMPT);
const storyStringParams = { const storyStringParams = {
description: description, description: description,
@ -4473,7 +4513,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
...thisPromptBits[currentArrayEntry], ...thisPromptBits[currentArrayEntry],
rawPrompt: generate_data.prompt || generate_data.input, rawPrompt: generate_data.prompt || generate_data.input,
mesId: getNextMessageId(type), mesId: getNextMessageId(type),
allAnchors: getAllExtensionPrompts(), allAnchors: await getAllExtensionPrompts(),
chatInjects: injectedIndices?.map(index => arrMes[arrMes.length - index - 1])?.join('') || '', chatInjects: injectedIndices?.map(index => arrMes[arrMes.length - index - 1])?.join('') || '',
summarizeString: (extension_prompts['1_memory']?.value || ''), summarizeString: (extension_prompts['1_memory']?.value || ''),
authorsNoteString: (extension_prompts['2_floating_prompt']?.value || ''), authorsNoteString: (extension_prompts['2_floating_prompt']?.value || ''),
@ -4742,9 +4782,9 @@ export function stopGeneration() {
* Injects extension prompts into chat messages. * Injects extension prompts into chat messages.
* @param {object[]} messages Array of chat messages * @param {object[]} messages Array of chat messages
* @param {boolean} isContinue Whether the generation is a continuation. If true, the extension prompts of depth 0 are injected at position 1. * @param {boolean} isContinue Whether the generation is a continuation. If true, the extension prompts of depth 0 are injected at position 1.
* @returns {number[]} Array of indices where the extension prompts were injected * @returns {Promise<number[]>} Array of indices where the extension prompts were injected
*/ */
function doChatInject(messages, isContinue) { async function doChatInject(messages, isContinue) {
const injectedIndices = []; const injectedIndices = [];
let totalInsertedMessages = 0; let totalInsertedMessages = 0;
messages.reverse(); messages.reverse();
@ -4762,7 +4802,7 @@ function doChatInject(messages, isContinue) {
const wrap = false; const wrap = false;
for (const role of roles) { for (const role of roles) {
const extensionPrompt = String(getExtensionPrompt(extension_prompt_types.IN_CHAT, i, separator, role, wrap)).trimStart(); const extensionPrompt = String(await getExtensionPrompt(extension_prompt_types.IN_CHAT, i, separator, role, wrap)).trimStart();
const isNarrator = role === extension_prompt_roles.SYSTEM; const isNarrator = role === extension_prompt_roles.SYSTEM;
const isUser = role === extension_prompt_roles.USER; const isUser = role === extension_prompt_roles.USER;
const name = names[role]; const name = names[role];
@ -7455,14 +7495,16 @@ function select_rm_characters() {
 * @param {number} depth Insertion depth. 0 represents the last message in context. Expected values up to MAX_INJECTION_DEPTH. * @param {number} depth Insertion depth. 0 represents the last message in context. Expected values up to MAX_INJECTION_DEPTH.
* @param {number} role Extension prompt role. Defaults to SYSTEM. * @param {number} role Extension prompt role. Defaults to SYSTEM.
* @param {boolean} scan Should the prompt be included in the world info scan. * @param {boolean} scan Should the prompt be included in the world info scan.
* @param {(function(): Promise<boolean>|boolean)} filter Filter function to determine if the prompt should be injected.
*/ */
export function setExtensionPrompt(key, value, position, depth, scan = false, role = extension_prompt_roles.SYSTEM, filter = null) {
    // Coerce each field on write; `filter` is stored exactly as given and is
    // type-checked by the readers before being called.
    const entry = {
        value: String(value),
        position: Number(position),
        depth: Number(depth),
        scan: Boolean(scan),
        role: Number(role ?? extension_prompt_roles.SYSTEM),
        filter,
    };

    extension_prompts[key] = entry;
}

View File

@ -611,8 +611,9 @@ function formatWorldInfo(value) {
* *
* @param {Prompt[]} prompts - Array containing injection prompts. * @param {Prompt[]} prompts - Array containing injection prompts.
* @param {Object[]} messages - Array containing all messages. * @param {Object[]} messages - Array containing all messages.
* @returns {Promise<Object[]>} - Array containing all messages with injections.
*/ */
function populationInjectionPrompts(prompts, messages) { async function populationInjectionPrompts(prompts, messages) {
let totalInsertedMessages = 0; let totalInsertedMessages = 0;
const roleTypes = { const roleTypes = {
@ -635,7 +636,7 @@ function populationInjectionPrompts(prompts, messages) {
// Get prompts for current role // Get prompts for current role
const rolePrompts = depthPrompts.filter(prompt => prompt.role === role).map(x => x.content).join(separator); const rolePrompts = depthPrompts.filter(prompt => prompt.role === role).map(x => x.content).join(separator);
// Get extension prompt // Get extension prompt
const extensionPrompt = getExtensionPrompt(extension_prompt_types.IN_CHAT, i, separator, roleTypes[role], wrap); const extensionPrompt = await getExtensionPrompt(extension_prompt_types.IN_CHAT, i, separator, roleTypes[role], wrap);
const jointPrompt = [rolePrompts, extensionPrompt].filter(x => x).map(x => x.trim()).join(separator); const jointPrompt = [rolePrompts, extensionPrompt].filter(x => x).map(x => x.trim()).join(separator);
@ -1020,7 +1021,7 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
} }
// Add in-chat injections // Add in-chat injections
messages = populationInjectionPrompts(absolutePrompts, messages); messages = await populationInjectionPrompts(absolutePrompts, messages);
// Decide whether dialogue examples should always be added // Decide whether dialogue examples should always be added
if (power_user.pin_examples) { if (power_user.pin_examples) {
@ -1051,9 +1052,9 @@ async function populateChatCompletion(prompts, chatCompletion, { bias, quietProm
* @param {string} options.systemPromptOverride * @param {string} options.systemPromptOverride
* @param {string} options.jailbreakPromptOverride * @param {string} options.jailbreakPromptOverride
* @param {string} options.personaDescription * @param {string} options.personaDescription
* @returns {Object} prompts - The prepared and merged system and user-defined prompts. * @returns {Promise<Object>} prompts - The prepared and merged system and user-defined prompts.
*/ */
function preparePromptsForChatCompletion({ Scenario, charPersonality, name2, worldInfoBefore, worldInfoAfter, charDescription, quietPrompt, bias, extensionPrompts, systemPromptOverride, jailbreakPromptOverride, personaDescription }) { async function preparePromptsForChatCompletion({ Scenario, charPersonality, name2, worldInfoBefore, worldInfoAfter, charDescription, quietPrompt, bias, extensionPrompts, systemPromptOverride, jailbreakPromptOverride, personaDescription }) {
const scenarioText = Scenario && oai_settings.scenario_format ? substituteParams(oai_settings.scenario_format) : ''; const scenarioText = Scenario && oai_settings.scenario_format ? substituteParams(oai_settings.scenario_format) : '';
const charPersonalityText = charPersonality && oai_settings.personality_format ? substituteParams(oai_settings.personality_format) : ''; const charPersonalityText = charPersonality && oai_settings.personality_format ? substituteParams(oai_settings.personality_format) : '';
const groupNudge = substituteParams(oai_settings.group_nudge_prompt); const groupNudge = substituteParams(oai_settings.group_nudge_prompt);
@ -1142,6 +1143,9 @@ function preparePromptsForChatCompletion({ Scenario, charPersonality, name2, wor
if (!extensionPrompts[key].value) continue; if (!extensionPrompts[key].value) continue;
if (![extension_prompt_types.BEFORE_PROMPT, extension_prompt_types.IN_PROMPT].includes(prompt.position)) continue; if (![extension_prompt_types.BEFORE_PROMPT, extension_prompt_types.IN_PROMPT].includes(prompt.position)) continue;
const hasFilter = typeof prompt.filter === 'function';
if (hasFilter && !await prompt.filter()) continue;
systemPrompts.push({ systemPrompts.push({
identifier: key.replace(/\W/g, '_'), identifier: key.replace(/\W/g, '_'),
position: getPromptPosition(prompt.position), position: getPromptPosition(prompt.position),
@ -1252,7 +1256,7 @@ export async function prepareOpenAIMessages({
try { try {
// Merge markers and ordered user prompts with system prompts // Merge markers and ordered user prompts with system prompts
const prompts = preparePromptsForChatCompletion({ const prompts = await preparePromptsForChatCompletion({
Scenario, Scenario,
charPersonality, charPersonality,
name2, name2,

View File

@ -3721,7 +3721,7 @@ export async function checkWorldInfo(chat, maxContext, isDryRun) {
// Put this code here since otherwise, the chat reference is modified // Put this code here since otherwise, the chat reference is modified
for (const key of Object.keys(context.extensionPrompts)) { for (const key of Object.keys(context.extensionPrompts)) {
if (context.extensionPrompts[key]?.scan) { if (context.extensionPrompts[key]?.scan) {
const prompt = getExtensionPromptByName(key); const prompt = await getExtensionPromptByName(key);
if (prompt) { if (prompt) {
buffer.addInject(prompt); buffer.addInject(prompt);
} }