From cde9903fcb603e58024780cb08b4546e792dbf25 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Thu, 14 Dec 2023 22:18:34 +0200 Subject: [PATCH] Fix Bison models --- public/script.js | 2 +- public/scripts/openai.js | 2 +- src/constants.js | 32 ++++++++++- src/endpoints/backends/chat-completions.js | 66 ++++++++++++++++++---- src/endpoints/google.js | 4 +- 5 files changed, 89 insertions(+), 17 deletions(-) diff --git a/public/script.js b/public/script.js index b5a7daf18..7d4ea3a37 100644 --- a/public/script.js +++ b/public/script.js @@ -2558,7 +2558,7 @@ function getCharacterCardFields() { function isStreamingEnabled() { const noStreamSources = [chat_completion_sources.SCALE, chat_completion_sources.AI21]; - return ((main_api == 'openai' && oai_settings.stream_openai && !noStreamSources.includes(oai_settings.chat_completion_source)) + return ((main_api == 'openai' && oai_settings.stream_openai && !noStreamSources.includes(oai_settings.chat_completion_source) && !(oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE && oai_settings.google_model.includes('bison'))) || (main_api == 'kobold' && kai_settings.streaming_kobold && kai_flags.can_use_streaming) || (main_api == 'novel' && nai_settings.streaming_novel) || (main_api == 'textgenerationwebui' && textgen_settings.streaming)); diff --git a/public/scripts/openai.js b/public/scripts/openai.js index db0d7c06d..9007017f6 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -1452,7 +1452,7 @@ async function sendOpenAIRequest(type, messages, signal) { const isQuiet = type === 'quiet'; const isImpersonate = type === 'impersonate'; const isContinue = type === 'continue'; - const stream = oai_settings.stream_openai && !isQuiet && !isScale && !isAI21; + const stream = oai_settings.stream_openai && !isQuiet && !isScale && !isAI21 && !(isGoogle && oai_settings.google_model.includes('bison')); if (isTextCompletion && isOpenRouter) { messages = 
convertChatCompletionToInstruct(messages, type); diff --git a/src/constants.js b/src/constants.js index 7151ada24..92af44cf2 100644 --- a/src/constants.js +++ b/src/constants.js @@ -105,7 +105,7 @@ const UNSAFE_EXTENSIONS = [ '.ws', ]; -const MAKERSUITE_SAFETY = [ +const GEMINI_SAFETY = [ { category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_NONE', @@ -124,6 +124,33 @@ const MAKERSUITE_SAFETY = [ }, ]; +const BISON_SAFETY = [ + { + category: 'HARM_CATEGORY_DEROGATORY', + threshold: 'BLOCK_NONE', + }, + { + category: 'HARM_CATEGORY_TOXICITY', + threshold: 'BLOCK_NONE', + }, + { + category: 'HARM_CATEGORY_VIOLENCE', + threshold: 'BLOCK_NONE', + }, + { + category: 'HARM_CATEGORY_SEXUAL', + threshold: 'BLOCK_NONE', + }, + { + category: 'HARM_CATEGORY_MEDICAL', + threshold: 'BLOCK_NONE', + }, + { + category: 'HARM_CATEGORY_DANGEROUS', + threshold: 'BLOCK_NONE', + }, +]; + const CHAT_COMPLETION_SOURCES = { OPENAI: 'openai', WINDOWAI: 'windowai', @@ -152,7 +179,8 @@ module.exports = { DIRECTORIES, UNSAFE_EXTENSIONS, UPLOADS_PATH, - MAKERSUITE_SAFETY, + GEMINI_SAFETY, + BISON_SAFETY, TEXTGEN_TYPES, CHAT_COMPLETION_SOURCES, AVATAR_WIDTH, diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index bd8969ac2..13e09cd56 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -3,7 +3,7 @@ const fetch = require('node-fetch').default; const { Readable } = require('stream'); const { jsonParser } = require('../../express-common'); -const { CHAT_COMPLETION_SOURCES, MAKERSUITE_SAFETY } = require('../../constants'); +const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY } = require('../../constants'); const { forwardFetchResponse, getConfigValue, tryParse, uuidv4 } = require('../../util'); const { convertClaudePrompt, convertGooglePrompt, convertTextCompletionPrompt } = require('../prompt-converters'); @@ -160,8 +160,10 @@ async function sendMakerSuiteRequest(request, 
response) { return response.status(400).send({ error: true }); } - const model = request.body.model; - const stream = request.body.stream; + const model = String(request.body.model); + const isGemini = model.includes('gemini'); + const isText = model.includes('text'); + const stream = Boolean(request.body.stream) && isGemini; const generationConfig = { stopSequences: request.body.stop, @@ -172,11 +174,48 @@ async function sendMakerSuiteRequest(request, response) { topK: request.body.top_k || undefined, }; - const body = { - contents: convertGooglePrompt(request.body.messages, model), - safetySettings: MAKERSUITE_SAFETY, - generationConfig: generationConfig, - }; + function getGeminiBody() { + return { + contents: convertGooglePrompt(request.body.messages, model), + safetySettings: GEMINI_SAFETY, + generationConfig: generationConfig, + }; + } + + function getBisonBody() { + const prompt = isText + ? ({ text: convertTextCompletionPrompt(request.body.messages) }) + : ({ messages: convertGooglePrompt(request.body.messages, model) }); + + /** @type {any} Shut the lint up */ + const bisonBody = { + ...generationConfig, + safetySettings: BISON_SAFETY, + candidate_count: 1, // legacy spelling + prompt: prompt, + }; + + if (!isText) { + delete bisonBody.stopSequences; + delete bisonBody.maxOutputTokens; + delete bisonBody.safetySettings; + + if (Array.isArray(prompt.messages)) { + for (const msg of prompt.messages) { + msg.author = msg.role; + msg.content = msg.parts[0].text; + delete msg.parts; + delete msg.role; + } + } + } + + delete bisonBody.candidateCount; + return bisonBody; + } + + const body = isGemini ? getGeminiBody() : getBisonBody(); + console.log('MakerSuite request:', body); try { const controller = new AbortController(); @@ -185,7 +224,12 @@ async function sendMakerSuiteRequest(request, response) { controller.abort(); }); - const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${model}:${stream ? 
'streamGenerateContent' : 'generateContent'}?key=${apiKey}`, { + const apiVersion = isGemini ? 'v1beta' : 'v1beta2'; + const responseType = isGemini + ? (stream ? 'streamGenerateContent' : 'generateContent') + : (isText ? 'generateText' : 'generateMessage'); + + const generateResponse = await fetch(`https://generativelanguage.googleapis.com/${apiVersion}/models/${model}:${responseType}?key=${apiKey}`, { body: JSON.stringify(body), method: 'POST', headers: { @@ -251,8 +295,8 @@ async function sendMakerSuiteRequest(request, response) { return response.send({ error: { message } }); } - const responseContent = candidates[0].content; - const responseText = responseContent.parts[0].text; + const responseContent = candidates[0].content ?? candidates[0].output; + const responseText = typeof responseContent === 'string' ? responseContent : responseContent.parts?.[0]?.text; if (!responseText) { let message = 'MakerSuite Candidate text empty'; console.log(message, generateResponseJson); diff --git a/src/endpoints/google.js b/src/endpoints/google.js index 1e74f71c7..010b6f0ea 100644 --- a/src/endpoints/google.js +++ b/src/endpoints/google.js @@ -2,7 +2,7 @@ const { readSecret, SECRET_KEYS } = require('./secrets'); const fetch = require('node-fetch').default; const express = require('express'); const { jsonParser } = require('../express-common'); -const { MAKERSUITE_SAFETY } = require('../constants'); +const { GEMINI_SAFETY } = require('../constants'); const router = express.Router(); @@ -22,7 +22,7 @@ router.post('/caption-image', jsonParser, async (request, response) => { }, }], }], - safetySettings: MAKERSUITE_SAFETY, + safetySettings: GEMINI_SAFETY, generationConfig: { maxOutputTokens: 1000 }, };