From dc4a6e862b862dd1c6e1ec6974429e6fb302b214 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Tue, 12 Sep 2023 00:15:21 +0300 Subject: [PATCH] Add local caption pipeline to UI plugin --- default/config.conf | 2 + public/scripts/extensions/caption/index.js | 87 +++++++++++++++---- .../scripts/extensions/caption/manifest.json | 4 +- 3 files changed, 72 insertions(+), 21 deletions(-) diff --git a/default/config.conf b/default/config.conf index 5e9d2920f..87801a256 100644 --- a/default/config.conf +++ b/default/config.conf @@ -19,6 +19,8 @@ const securityOverride = false; const extras = { // Text classification model for sentiment analysis. HuggingFace ID of a model in ONNX format. classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx', + // Image captioning model. HuggingFace ID of a model in ONNX format. + captioningModel: 'Xenova/vit-gpt2-image-captioning', }; // Request overrides for additional headers diff --git a/public/scripts/extensions/caption/index.js b/public/scripts/extensions/caption/index.js index 6e12684eb..5e8f9a64a 100644 --- a/public/scripts/extensions/caption/index.js +++ b/public/scripts/extensions/caption/index.js @@ -1,6 +1,6 @@ -import { getBase64Async } from "../../utils.js"; -import { getContext, getApiUrl, doExtrasFetch, extension_settings } from "../../extensions.js"; -import { callPopup, saveSettingsDebounced } from "../../../script.js"; +import { getBase64Async, saveBase64AsFile } from "../../utils.js"; +import { getContext, getApiUrl, doExtrasFetch, extension_settings, modules } from "../../extensions.js"; +import { callPopup, getRequestHeaders, saveSettingsDebounced } from "../../../script.js"; import { getMessageTimeStamp } from "../../RossAscends-mods.js"; export { MODULE_NAME }; @@ -8,7 +8,8 @@ const MODULE_NAME = 'caption'; const UPDATE_INTERVAL = 1000; async function moduleWorker() { - $('#send_picture').toggle(getContext().onlineStatus !== 'no_connection'); + const hasConnection = getContext().onlineStatus !== 'no_connection'; + $('#send_picture').toggle(hasConnection); } async function setImageIcon() { @@ -65,16 +66,21 @@ async function sendCaptionedMessage(caption, image) { await context.generate('caption'); } -async function onSelectImage(e) { - setSpinnerIcon(); - const file = e.target.files[0]; +async function doCaptionRequest(base64Img) { + if (extension_settings.caption.local) { + const apiResult = await fetch('/api/extra/caption', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ image: base64Img }) + }); - if (!file) { - return; - } + if (!apiResult.ok) { + throw new Error('Failed to caption image via local pipeline.'); + } - try { - const base64Img = await getBase64Async(file); + const data = await apiResult.json(); + return data; + } else if (modules.includes('caption')) { const url = new URL(getApiUrl()); url.pathname = '/api/caption'; @@ -84,17 +90,42 @@ async function onSelectImage(e) { 'Content-Type': 'application/json', 'Bypass-Tunnel-Reminder': 'bypass', }, - body: JSON.stringify({ image: base64Img.split(',')[1] }) + body: JSON.stringify({ image: base64Img }) }); - if (apiResult.ok) { - const data = await apiResult.json(); - const caption = data.caption; - const imageToSave = data.thumbnail ? `data:image/jpeg;base64,${data.thumbnail}` : base64Img; - await sendCaptionedMessage(caption, imageToSave); + if (!apiResult.ok) { + throw new Error('Failed to caption image via Extras.'); } + + const data = await apiResult.json(); + return data; + } else { + throw new Error('No captioning module is available.'); + } +} + +async function onSelectImage(e) { + setSpinnerIcon(); + const file = e.target.files[0]; + + if (!file || !(file instanceof File)) { + return; + } + + try { + const context = getContext(); + const fileData = await getBase64Async(file); + const base64Format = fileData.split(',')[0].split(';')[0].split('/')[1]; + const base64Data = fileData.split(',')[1]; + const data = await doCaptionRequest(base64Data); + const caption = data.caption; + const imageToSave = data.thumbnail ? data.thumbnail : base64Data; + const format = data.thumbnail ? 'jpeg' : base64Format; + const imagePath = await saveBase64AsFile(imageToSave, context.name2, '', format); + await sendCaptionedMessage(caption, imagePath); } catch (error) { + toastr.error('Failed to caption image.'); console.log(error); } finally { @@ -118,7 +149,16 @@ jQuery(function () { $('#extensionsMenu').prepend(sendButton); $(sendButton).hide(); - $(sendButton).on('click', () => $('#img_file').trigger('click')); + $(sendButton).on('click', () => { + const hasCaptionModule = modules.includes('caption') || extension_settings.caption.local; + + if (!hasCaptionModule) { + toastr.error('No captioning module is available. Either enable the local captioning pipeline or connect to Extras.'); + return; + } + + $('#img_file').trigger('click'); + }); } function addPictureSendForm() { const inputHtml = ``; @@ -138,6 +178,10 @@ jQuery(function () {
+