Add multimodal captioning for SD prompt generation

Cohee
2023-11-19 15:24:43 +02:00
parent c3e5d0f6f2
commit b0b19edf31
4 changed files with 137 additions and 10 deletions


@@ -12,13 +12,18 @@ import {
getCurrentChatId,
animation_duration,
appendMediaToMessage,
getUserAvatar,
user_avatar,
getCharacterAvatar,
formatCharacterAvatar,
} from "../../../script.js";
import { getApiUrl, getContext, extension_settings, doExtrasFetch, modules, renderExtensionTemplate } from "../../extensions.js";
import { selected_group } from "../../group-chats.js";
import { stringFormat, initScrollHeight, resetScrollHeight, getCharaFilename, saveBase64AsFile } from "../../utils.js";
import { stringFormat, initScrollHeight, resetScrollHeight, getCharaFilename, saveBase64AsFile, getBase64Async } from "../../utils.js";
import { getMessageTimeStamp, humanizedDateTime } from "../../RossAscends-mods.js";
import { SECRET_KEYS, secret_state } from "../../secrets.js";
import { getNovelUnlimitedImageGeneration, getNovelAnlas, loadNovelSubscriptionData } from "../../nai-settings.js";
import { getMultimodalCaption } from "../shared.js";
export { MODULE_NAME };
// Wraps a string into monospace font-face span
@@ -49,6 +54,15 @@ const generationMode = {
FACE: 5,
FREE: 6,
BACKGROUND: 7,
CHARACTER_MULTIMODAL: 8,
USER_MULTIMODAL: 9,
FACE_MULTIMODAL: 10,
}
const multimodalMap = {
[generationMode.CHARACTER]: generationMode.CHARACTER_MULTIMODAL,
[generationMode.USER]: generationMode.USER_MULTIMODAL,
[generationMode.FACE]: generationMode.FACE_MULTIMODAL,
}
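A quick illustration (not part of the commit) of how this map is consumed: a base mode resolves to its multimodal counterpart only if an entry exists, so modes without one (BACKGROUND, FREE, and so on) are left as-is. The actual lookup is added to getGenerationType further down in this diff.
// Illustration only — mirrors the lookup added to getGenerationType below.
multimodalMap[generationMode.CHARACTER];  // generationMode.CHARACTER_MULTIMODAL (8)
multimodalMap[generationMode.BACKGROUND]; // undefined, so the base mode is kept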
const modeLabels = {
@@ -59,6 +73,9 @@ const modeLabels = {
[generationMode.NOW]: 'Last Message',
[generationMode.RAW_LAST]: 'Raw Last Message',
[generationMode.BACKGROUND]: 'Background',
[generationMode.CHARACTER_MULTIMODAL]: 'Character (Multimodal Mode)',
[generationMode.FACE_MULTIMODAL]: 'Portrait (Multimodal Mode)',
[generationMode.USER_MULTIMODAL]: 'User (Multimodal Mode)',
}
const triggerWords = {
@@ -118,6 +135,9 @@ const promptTemplates = {
[generationMode.RAW_LAST]: "[Pause your roleplay and provide ONLY the last chat message string back to me verbatim. Do not write anything after the string. Do not roleplay at all in your response. Do not continue the roleplay story.]",
[generationMode.BACKGROUND]: "[Pause your roleplay and provide a detailed description of {{char}}'s surroundings in the form of a comma-delimited list of keywords and phrases. The list must include all of the following items in this order: location, time of day, weather, lighting, and any other relevant details. Do not include descriptions of characters and non-visual qualities such as names, personality, movements, scents, mental traits, or anything which could not be seen in a still photograph. Do not write in full sentences. Prefix your description with the phrase 'background,'. Ignore the rest of the story when crafting this description. Do not roleplay as {{user}} when writing this description, and do not attempt to continue the story.]",
[generationMode.FACE_MULTIMODAL]: `Provide an exhaustive comma-separated list of tags describing the appearance of the character on this image in great detail. Start with "close-up portrait".`,
[generationMode.CHARACTER_MULTIMODAL]: `Provide an exhaustive comma-separated list of tags describing the appearance of the character on this image in great detail. Start with "full body portrait".`,
[generationMode.USER_MULTIMODAL]: `Provide an exhaustive comma-separated list of tags describing the appearance of the character on this image in great detail. Start with "full body portrait".`,
}
const helpString = [
@@ -177,6 +197,7 @@ const defaultSettings = {
refine_mode: false,
expand: false,
interactive_mode: false,
multimodal_captioning: false,
prompts: promptTemplates,
@@ -342,6 +363,7 @@ async function loadSettings() {
$('#sd_enable_hr').prop('checked', extension_settings.sd.enable_hr);
$('#sd_refine_mode').prop('checked', extension_settings.sd.refine_mode);
$('#sd_expand').prop('checked', extension_settings.sd.expand);
$('#sd_multimodal_captioning').prop('checked', extension_settings.sd.multimodal_captioning);
$('#sd_auto_url').val(extension_settings.sd.auto_url);
$('#sd_auto_auth').val(extension_settings.sd.auto_auth);
$('#sd_vlad_url').val(extension_settings.sd.vlad_url);
@@ -401,6 +423,11 @@ function onInteractiveModeInput() {
saveSettingsDebounced();
}
function onMultimodalCaptioningInput() {
extension_settings.sd.multimodal_captioning = !!$(this).prop('checked');
saveSettingsDebounced();
}
function onStyleSelect() {
const selectedStyle = String($('#sd_style').find(':selected').val());
const styleObject = extension_settings.sd.styles.find(x => x.name === selectedStyle);
@@ -1205,15 +1232,22 @@ async function loadNovelModels() {
}
function getGenerationType(prompt) {
let mode = generationMode.FREE;
for (const [key, values] of Object.entries(triggerWords)) {
for (const value of values) {
if (value.toLowerCase() === prompt.toLowerCase().trim()) {
return Number(key);
mode = Number(key);
break;
}
}
}
return generationMode.FREE;
if (extension_settings.sd.multimodal_captioning && multimodalMap[mode] !== undefined) {
mode = multimodalMap[mode];
}
return mode;
}
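In short, the early return is replaced by a two-step resolution: the trigger word selects a base mode, and the map above upgrades it to the multimodal variant when the new setting is enabled. A hedged usage sketch follows; the literal trigger string is an assumption, since the triggerWords table is elided from this excerpt.
// Sketch only — 'you' as a CHARACTER trigger word is an assumption.
extension_settings.sd.multimodal_captioning = false;
getGenerationType('you'); // -> generationMode.CHARACTER
extension_settings.sd.multimodal_captioning = true;
getGenerationType('you'); // -> generationMode.CHARACTER_MULTIMODAL
getGenerationType('no trigger word here'); // -> generationMode.FREE in both cases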
function getQuietPrompt(mode, trigger) {
@@ -1284,7 +1318,7 @@ async function generatePicture(_, trigger, message, callback) {
trigger = trigger.trim();
const generationType = getGenerationType(trigger);
console.log('Generation mode', generationType, 'triggered with', trigger);
const quiet_prompt = getQuietPrompt(generationType, trigger);
const quietPrompt = getQuietPrompt(generationType, trigger);
const context = getContext();
// if context.characterId is not null, then we get context.characters[context.characterId].avatar, else we get groupId and context.groups[groupId].id
@@ -1308,7 +1342,7 @@ async function generatePicture(_, trigger, message, callback) {
const dimensions = setTypeSpecificDimensions(generationType);
try {
const prompt = await getPrompt(generationType, message, trigger, quiet_prompt);
const prompt = await getPrompt(generationType, message, trigger, quietPrompt);
console.log('Processed image prompt:', prompt);
context.deactivateSendButtons();
@@ -1353,7 +1387,7 @@ function restoreOriginalDimensions(savedParams) {
extension_settings.sd.width = savedParams.width;
}
async function getPrompt(generationType, message, trigger, quiet_prompt) {
async function getPrompt(generationType, message, trigger, quietPrompt) {
let prompt;
switch (generationType) {
@@ -1363,8 +1397,13 @@ async function getPrompt(generationType, message, trigger, quiet_prompt) {
case generationMode.FREE:
prompt = trigger.trim();
break;
case generationMode.FACE_MULTIMODAL:
case generationMode.CHARACTER_MULTIMODAL:
case generationMode.USER_MULTIMODAL:
prompt = await generateMultimodalPrompt(generationType, quietPrompt);
break;
default:
prompt = await generatePrompt(quiet_prompt);
prompt = await generatePrompt(quietPrompt);
break;
}
@@ -1375,8 +1414,57 @@ async function getPrompt(generationType, message, trigger, quiet_prompt) {
return prompt;
}
async function generatePrompt(quiet_prompt) {
const reply = await generateQuietPrompt(quiet_prompt, false, false);
/**
* Generates a prompt using multimodal captioning.
* @param {number} generationType - The type of image generation to perform.
* @param {string} quietPrompt - The prompt to use for the image generation.
*/
async function generateMultimodalPrompt(generationType, quietPrompt) {
let avatarUrl;
if (generationType == generationMode.USER_MULTIMODAL) {
avatarUrl = getUserAvatar(user_avatar);
}
if (generationType == generationMode.CHARACTER_MULTIMODAL || generationType === generationMode.FACE_MULTIMODAL) {
const context = getContext();
if (context.groupId) {
const groupMembers = context.groups.find(x => x.id === context.groupId)?.members;
const lastMessageAvatar = context.chat?.filter(x => !x.is_system && !x.is_user)?.slice(-1)[0]?.original_avatar;
const randomMemberAvatar = Array.isArray(groupMembers) ? groupMembers[Math.floor(Math.random() * groupMembers.length)]?.avatar : null;
const avatarToUse = lastMessageAvatar || randomMemberAvatar;
avatarUrl = formatCharacterAvatar(avatarToUse);
} else {
avatarUrl = getCharacterAvatar(context.characterId);
}
}
const response = await fetch(avatarUrl);
if (!response.ok) {
throw new Error('Could not fetch avatar image.');
}
const avatarBlob = await response.blob();
const avatarBase64 = await getBase64Async(avatarBlob);
const caption = await getMultimodalCaption(avatarBase64, quietPrompt);
if (!caption) {
throw new Error('Multimodal captioning failed.');
}
return caption;
}
/**
* Generates a prompt using the main LLM API.
* @param {string} quietPrompt - The prompt to use for the image generation.
* @returns {Promise<string>} - A promise that resolves when the prompt generation completes.
*/
async function generatePrompt(quietPrompt) {
const reply = await generateQuietPrompt(quietPrompt, false, false);
return processReply(reply);
}
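Condensing the pieces added above into one hedged sketch of the new multimodal path (illustration only; the message argument is a placeholder, and getMultimodalCaption is the helper imported from ../shared.js at the top of this diff):
// Illustration of the flow introduced by this commit (not part of the diff).
async function exampleMultimodalFlow(trigger) {
    const generationType = getGenerationType(trigger);           // e.g. generationMode.CHARACTER_MULTIMODAL
    const quietPrompt = getQuietPrompt(generationType, trigger);  // the prompt template for that mode
    // For the *_MULTIMODAL modes, getPrompt() routes to generateMultimodalPrompt(),
    // which fetches the avatar image, converts it to base64, and returns the
    // caption produced by getMultimodalCaption(avatarBase64, quietPrompt).
    return await getPrompt(generationType, null, trigger, quietPrompt);
}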
@@ -1932,6 +2020,7 @@ jQuery(async () => {
$('#sd_interactive_mode').on('input', onInteractiveModeInput);
$('#sd_openai_style').on('change', onOpenAiStyleSelect);
$('#sd_openai_quality').on('change', onOpenAiQualitySelect);
$('#sd_multimodal_captioning').on('input', onMultimodalCaptioningInput);
$('.sd_settings .inline-drawer-toggle').on('click', function () {
initScrollHeight($("#sd_prompt_prefix"));