Merge pull request #2805 from SillyTavern/gemini-vision-fix

Fix Gemini multimodal with JPG images
This commit is contained in:
Cohee 2024-09-08 17:51:20 +03:00 committed by GitHub
commit 5e522d6e35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 33 additions and 13 deletions

View File

@ -20,7 +20,7 @@ export async function getMultimodalCaption(base64Img, prompt) {
throwIfInvalidModel(useReverseProxy);
const noPrefix = ['google', 'ollama', 'llamacpp'].includes(extension_settings.caption.multimodal_api);
const noPrefix = ['ollama', 'llamacpp'].includes(extension_settings.caption.multimodal_api);
if (noPrefix && base64Img.startsWith('data:image/')) {
base64Img = base64Img.split(',')[1];
@ -28,7 +28,6 @@ export async function getMultimodalCaption(base64Img, prompt) {
// OpenRouter has a payload limit of ~2MB. Google is 4MB, but we love democracy.
// Ooba requires all images to be JPEGs. Koboldcpp just asked nicely.
const isGoogle = extension_settings.caption.multimodal_api === 'google';
const isOllama = extension_settings.caption.multimodal_api === 'ollama';
const isLlamaCpp = extension_settings.caption.multimodal_api === 'llamacpp';
const isCustom = extension_settings.caption.multimodal_api === 'custom';
@ -40,10 +39,6 @@ export async function getMultimodalCaption(base64Img, prompt) {
if ((['google', 'openrouter'].includes(extension_settings.caption.multimodal_api) && base64Bytes > compressionLimit) || isOoba || isKoboldCpp) {
const maxSide = 1024;
base64Img = await createThumbnail(base64Img, maxSide, maxSide, 'image/jpeg');
if (isGoogle) {
base64Img = base64Img.split(',')[1];
}
}
const proxyUrl = useReverseProxy ? oai_settings.reverse_proxy : '';

View File

@ -47,6 +47,7 @@ import { SECRET_KEYS, secret_state, writeSecret } from './secrets.js';
import { getEventSourceStream } from './sse-stream.js';
import {
createThumbnail,
delay,
download,
getBase64Async,
@ -2440,15 +2441,14 @@ class Message {
if (!response.ok) throw new Error('Failed to fetch image');
const blob = await response.blob();
image = await getBase64Async(blob);
if (oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE) {
image = image.split(',')[1];
}
} catch (error) {
console.error('Image adding skipped', error);
return;
}
}
image = await this.compressImage(image);
const quality = oai_settings.inline_image_quality || default_settings.inline_image_quality;
this.content = [
{ type: 'text', text: textContent },
@ -2464,6 +2464,29 @@ class Message {
}
}
/**
* Compress an image if it exceeds the size threshold for the current chat completion source.
* @param {string} image Data URL of the image.
* @returns {Promise<string>} Compressed image as a Data URL.
*/
async compressImage(image) {
if ([chat_completion_sources.OPENROUTER, chat_completion_sources.MAKERSUITE].includes(oai_settings.chat_completion_source)) {
const sizeThreshold = 2 * 1024 * 1024;
const dataSize = image.length * 0.75;
const maxSide = 1024;
if (dataSize > sizeThreshold) {
image = await createThumbnail(image, maxSide);
}
}
return image;
}
/**
* Get the token cost of an image.
* @param {string} dataUrl Data URL of the image.
* @param {string} quality String representing the quality of the image. Can be 'low', 'auto', or 'high'.
* @returns {Promise<number>} The token cost of the image.
*/
async getImageTokenCost(dataUrl, quality) {
if (quality === 'low') {
return Message.tokensPerImage;

View File

@ -22,8 +22,8 @@ router.post('/caption-image', jsonParser, async (request, response) => {
{ text: request.body.prompt },
{
inlineData: {
mimeType: 'image/png', // It needs to specify a MIME type in data if it's not a PNG
data: mimeType === 'image/png' ? base64Data : request.body.image,
mimeType: mimeType,
data: base64Data,
},
}],
}],

View File

@ -335,10 +335,12 @@ function convertGooglePrompt(messages, model, useSysPrompt = false, charName = '
if (part.type === 'text') {
parts.push({ text: part.text });
} else if (part.type === 'image_url' && isMultimodal) {
const mimeType = part.image_url.url.split(';')[0].split(':')[1];
const base64Data = part.image_url.url.split(',')[1];
parts.push({
inlineData: {
mimeType: 'image/png',
data: part.image_url.url,
mimeType: mimeType,
data: base64Data,
},
});
hasImage = true;