From 34bca69950c88f06bae145f18bcb7dc4d1207e89 Mon Sep 17 00:00:00 2001 From: based Date: Thu, 11 Apr 2024 16:38:20 +1000 Subject: [PATCH 1/4] system prompt support for gemini 1.5 --- public/index.html | 11 +++++++++++ public/scripts/openai.js | 12 ++++++++++++ src/endpoints/backends/chat-completions.js | 16 ++++++++++++---- src/endpoints/tokenizers.js | 2 +- src/prompt-converters.js | 17 ++++++++++++++--- 5 files changed, 50 insertions(+), 8 deletions(-) diff --git a/public/index.html b/public/index.html index 46942e6d9..164b1cd0e 100644 --- a/public/index.html +++ b/public/index.html @@ -1761,6 +1761,17 @@
Use the appropriate tokenizer for Google models via their API. Slower prompt processing, but offers much more accurate token counting.
+ +
+ + Merges all system messages up until the first message with a non-system role, and sends them through Google's system_instruction field instead of with the rest of the prompt contents. + 
diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 98fa51b3d..8a71b89f0 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -260,6 +260,7 @@ const default_settings = { use_ai21_tokenizer: false, use_google_tokenizer: false, claude_use_sysprompt: false, + use_makersuite_sysprompt: true, use_alt_scale: false, squash_system_messages: false, image_inlining: false, @@ -330,6 +331,7 @@ const oai_settings = { use_ai21_tokenizer: false, use_google_tokenizer: false, claude_use_sysprompt: false, + use_makersuite_sysprompt: true, use_alt_scale: false, squash_system_messages: false, image_inlining: false, @@ -1733,6 +1735,7 @@ async function sendOpenAIRequest(type, messages, signal) { const stopStringsLimit = 3; // 5 - 2 (nameStopString and new_chat_prompt) generate_data['top_k'] = Number(oai_settings.top_k_openai); generate_data['stop'] = [nameStopString, substituteParams(oai_settings.new_chat_prompt), ...getCustomStoppingStrings(stopStringsLimit)]; + generate_data['use_makersuite_sysprompt'] = oai_settings.use_makersuite_sysprompt; } if (isAI21) { @@ -2668,6 +2671,7 @@ function loadOpenAISettings(data, settings) { if (settings.use_ai21_tokenizer !== undefined) { oai_settings.use_ai21_tokenizer = !!settings.use_ai21_tokenizer; oai_settings.use_ai21_tokenizer ? 
ai21_max = 8191 : ai21_max = 9200; } if (settings.use_google_tokenizer !== undefined) oai_settings.use_google_tokenizer = !!settings.use_google_tokenizer; if (settings.claude_use_sysprompt !== undefined) oai_settings.claude_use_sysprompt = !!settings.claude_use_sysprompt; + if (settings.use_makersuite_sysprompt !== undefined) oai_settings.use_makersuite_sysprompt = !!settings.use_makersuite_sysprompt; if (settings.use_alt_scale !== undefined) { oai_settings.use_alt_scale = !!settings.use_alt_scale; updateScaleForm(); } $('#stream_toggle').prop('checked', oai_settings.stream_openai); $('#api_url_scale').val(oai_settings.api_url_scale); @@ -2707,6 +2711,7 @@ function loadOpenAISettings(data, settings) { $('#use_ai21_tokenizer').prop('checked', oai_settings.use_ai21_tokenizer); $('#use_google_tokenizer').prop('checked', oai_settings.use_google_tokenizer); $('#claude_use_sysprompt').prop('checked', oai_settings.claude_use_sysprompt); + $('#use_makersuite_sysprompt').prop('checked', oai_settings.use_makersuite_sysprompt); $('#scale-alt').prop('checked', oai_settings.use_alt_scale); $('#openrouter_use_fallback').prop('checked', oai_settings.openrouter_use_fallback); $('#openrouter_force_instruct').prop('checked', oai_settings.openrouter_force_instruct); @@ -2976,6 +2981,7 @@ async function saveOpenAIPreset(name, settings, triggerUi = true) { use_ai21_tokenizer: settings.use_ai21_tokenizer, use_google_tokenizer: settings.use_google_tokenizer, claude_use_sysprompt: settings.claude_use_sysprompt, + use_makersuite_sysprompt: settings.use_makersuite_sysprompt, use_alt_scale: settings.use_alt_scale, squash_system_messages: settings.squash_system_messages, image_inlining: settings.image_inlining, @@ -3354,6 +3360,7 @@ function onSettingsPresetChange() { use_ai21_tokenizer: ['#use_ai21_tokenizer', 'use_ai21_tokenizer', true], use_google_tokenizer: ['#use_google_tokenizer', 'use_google_tokenizer', true], claude_use_sysprompt: ['#claude_use_sysprompt', 'claude_use_sysprompt', 
true], + use_makersuite_sysprompt: ['#use_makersuite_sysprompt', 'use_makersuite_sysprompt', true], use_alt_scale: ['#use_alt_scale', 'use_alt_scale', true], squash_system_messages: ['#squash_system_messages', 'squash_system_messages', true], image_inlining: ['#openai_image_inlining', 'image_inlining', true], @@ -4290,6 +4297,11 @@ $(document).ready(async function () { saveSettingsDebounced(); }); + $('#use_makersuite_sysprompt').on('change', function () { + oai_settings.use_makersuite_sysprompt = !!$('#use_makersuite_sysprompt').prop('checked'); + saveSettingsDebounced(); + }); + $('#send_if_empty_textarea').on('input', function () { oai_settings.send_if_empty = String($('#send_if_empty_textarea').val()); saveSettingsDebounced(); diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index b1613fbc2..3f1fcf610 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -252,17 +252,25 @@ async function sendMakerSuiteRequest(request, response) { }; function getGeminiBody() { - return { - contents: convertGooglePrompt(request.body.messages, model), + let should_use_system_prompt = model === 'gemini-1.5-pro-latest' && request.body.use_makersuite_sysprompt; + let prompt = convertGooglePrompt(request.body.messages, model, should_use_system_prompt); + let body = { + contents: prompt.contents, safetySettings: GEMINI_SAFETY, generationConfig: generationConfig, - }; + } + + if (should_use_system_prompt) { + body.system_instruction = prompt.system_instruction; + } + + return body; } function getBisonBody() { const prompt = isText ? 
({ text: convertTextCompletionPrompt(request.body.messages) }) - : ({ messages: convertGooglePrompt(request.body.messages, model) }); + : ({ messages: convertGooglePrompt(request.body.messages, model).contents }); /** @type {any} Shut the lint up */ const bisonBody = { diff --git a/src/endpoints/tokenizers.js b/src/endpoints/tokenizers.js index e6fba800a..10cce1c2a 100644 --- a/src/endpoints/tokenizers.js +++ b/src/endpoints/tokenizers.js @@ -398,7 +398,7 @@ router.post('/google/count', jsonParser, async function (req, res) { accept: 'application/json', 'content-type': 'application/json', }, - body: JSON.stringify({ contents: convertGooglePrompt(req.body, String(req.query.model)) }), + body: JSON.stringify({ contents: convertGooglePrompt(req.body, String(req.query.model)).contents }), }; try { const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${req.query.model}:countTokens?key=${readSecret(SECRET_KEYS.MAKERSUITE)}`, options); diff --git a/src/prompt-converters.js b/src/prompt-converters.js index 72b75e223..517cb5e89 100644 --- a/src/prompt-converters.js +++ b/src/prompt-converters.js @@ -252,9 +252,10 @@ function convertCohereMessages(messages, charName = '', userName = '') { * Convert a prompt from the ChatML objects to the format used by Google MakerSuite models. 
* @param {object[]} messages Array of messages * @param {string} model Model name - * @returns {object[]} Prompt for Google MakerSuite models + * @param {boolean} useSysPrompt Use system prompt + * @returns {{contents: *[], system_instruction: {parts: {text: string}}}} Prompt for Google MakerSuite models */ -function convertGooglePrompt(messages, model) { +function convertGooglePrompt(messages, model, useSysPrompt = false) { // This is a 1x1 transparent PNG const PNG_PIXEL = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII='; @@ -267,6 +268,16 @@ function convertGooglePrompt(messages, model) { const isMultimodal = visionSupportedModels.includes(model); let hasImage = false; + let sys_prompt = ''; + if (useSysPrompt) { + while (messages.length > 1 && messages[0].role === 'system') { + sys_prompt += `${messages[0].content}\n\n`; + messages.shift(); + } + } + + const system_instruction = { parts: { text: sys_prompt }}; + const contents = []; messages.forEach((message, index) => { // fix the roles @@ -327,7 +338,7 @@ function convertGooglePrompt(messages, model) { }); } - return contents; + return { contents: contents, system_instruction: system_instruction }; } /** From c4ec97aa5053dc64ced51fd3cda4336150e9445f Mon Sep 17 00:00:00 2001 From: based Date: Thu, 11 Apr 2024 16:51:05 +1000 Subject: [PATCH 2/4] cleanup --- src/endpoints/backends/chat-completions.js | 4 ++-- src/prompt-converters.js | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index 3f1fcf610..c3e2b5ea7 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -252,8 +252,8 @@ async function sendMakerSuiteRequest(request, response) { }; function getGeminiBody() { - let should_use_system_prompt = model === 'gemini-1.5-pro-latest' && request.body.use_makersuite_sysprompt; - let prompt = 
convertGooglePrompt(request.body.messages, model, should_use_system_prompt); + const should_use_system_prompt = model === 'gemini-1.5-pro-latest' && request.body.use_makersuite_sysprompt; + const prompt = convertGooglePrompt(request.body.messages, model, should_use_system_prompt); let body = { contents: prompt.contents, safetySettings: GEMINI_SAFETY, diff --git a/src/prompt-converters.js b/src/prompt-converters.js index 517cb5e89..b2fb937cb 100644 --- a/src/prompt-converters.js +++ b/src/prompt-converters.js @@ -276,7 +276,7 @@ function convertGooglePrompt(messages, model, useSysPrompt = false) { } } - const system_instruction = { parts: { text: sys_prompt }}; + const system_instruction = { parts: { text: sys_prompt.trim() }}; const contents = []; messages.forEach((message, index) => { From 4ac6bbd515610dd87e005159c8435f734a044e49 Mon Sep 17 00:00:00 2001 From: based Date: Thu, 11 Apr 2024 17:01:19 +1000 Subject: [PATCH 3/4] thought it looked a little strange --- public/index.html | 2 ++ 1 file changed, 2 insertions(+) diff --git a/public/index.html b/public/index.html index 164b1cd0e..b7c03a37a 100644 --- a/public/index.html +++ b/public/index.html @@ -1761,6 +1761,8 @@
Use the appropriate tokenizer for Google models via their API. Slower prompt processing, but offers much more accurate token counting.
+
+