Fix Bison models

This commit is contained in:
Cohee
2023-12-14 22:18:34 +02:00
parent eec28469f8
commit cde9903fcb
5 changed files with 89 additions and 17 deletions

View File

@ -2558,7 +2558,7 @@ function getCharacterCardFields() {
function isStreamingEnabled() { function isStreamingEnabled() {
const noStreamSources = [chat_completion_sources.SCALE, chat_completion_sources.AI21]; const noStreamSources = [chat_completion_sources.SCALE, chat_completion_sources.AI21];
return ((main_api == 'openai' && oai_settings.stream_openai && !noStreamSources.includes(oai_settings.chat_completion_source)) return ((main_api == 'openai' && oai_settings.stream_openai && !noStreamSources.includes(oai_settings.chat_completion_source) && !(oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE && oai_settings.google_model.includes('bison')))
|| (main_api == 'kobold' && kai_settings.streaming_kobold && kai_flags.can_use_streaming) || (main_api == 'kobold' && kai_settings.streaming_kobold && kai_flags.can_use_streaming)
|| (main_api == 'novel' && nai_settings.streaming_novel) || (main_api == 'novel' && nai_settings.streaming_novel)
|| (main_api == 'textgenerationwebui' && textgen_settings.streaming)); || (main_api == 'textgenerationwebui' && textgen_settings.streaming));

View File

@ -1452,7 +1452,7 @@ async function sendOpenAIRequest(type, messages, signal) {
const isQuiet = type === 'quiet'; const isQuiet = type === 'quiet';
const isImpersonate = type === 'impersonate'; const isImpersonate = type === 'impersonate';
const isContinue = type === 'continue'; const isContinue = type === 'continue';
const stream = oai_settings.stream_openai && !isQuiet && !isScale && !isAI21; const stream = oai_settings.stream_openai && !isQuiet && !isScale && !isAI21 && !(isGoogle && oai_settings.google_model.includes('bison'));
if (isTextCompletion && isOpenRouter) { if (isTextCompletion && isOpenRouter) {
messages = convertChatCompletionToInstruct(messages, type); messages = convertChatCompletionToInstruct(messages, type);

View File

@ -105,7 +105,7 @@ const UNSAFE_EXTENSIONS = [
'.ws', '.ws',
]; ];
const MAKERSUITE_SAFETY = [ const GEMINI_SAFETY = [
{ {
category: 'HARM_CATEGORY_HARASSMENT', category: 'HARM_CATEGORY_HARASSMENT',
threshold: 'BLOCK_NONE', threshold: 'BLOCK_NONE',
@ -124,6 +124,33 @@ const MAKERSUITE_SAFETY = [
}, },
]; ];
const BISON_SAFETY = [
{
category: 'HARM_CATEGORY_DEROGATORY',
threshold: 'BLOCK_NONE',
},
{
category: 'HARM_CATEGORY_TOXICITY',
threshold: 'BLOCK_NONE',
},
{
category: 'HARM_CATEGORY_VIOLENCE',
threshold: 'BLOCK_NONE',
},
{
category: 'HARM_CATEGORY_SEXUAL',
threshold: 'BLOCK_NONE',
},
{
category: 'HARM_CATEGORY_MEDICAL',
threshold: 'BLOCK_NONE',
},
{
category: 'HARM_CATEGORY_DANGEROUS',
threshold: 'BLOCK_NONE',
},
];
const CHAT_COMPLETION_SOURCES = { const CHAT_COMPLETION_SOURCES = {
OPENAI: 'openai', OPENAI: 'openai',
WINDOWAI: 'windowai', WINDOWAI: 'windowai',
@ -152,7 +179,8 @@ module.exports = {
DIRECTORIES, DIRECTORIES,
UNSAFE_EXTENSIONS, UNSAFE_EXTENSIONS,
UPLOADS_PATH, UPLOADS_PATH,
MAKERSUITE_SAFETY, GEMINI_SAFETY,
BISON_SAFETY,
TEXTGEN_TYPES, TEXTGEN_TYPES,
CHAT_COMPLETION_SOURCES, CHAT_COMPLETION_SOURCES,
AVATAR_WIDTH, AVATAR_WIDTH,

View File

@ -3,7 +3,7 @@ const fetch = require('node-fetch').default;
const { Readable } = require('stream'); const { Readable } = require('stream');
const { jsonParser } = require('../../express-common'); const { jsonParser } = require('../../express-common');
const { CHAT_COMPLETION_SOURCES, MAKERSUITE_SAFETY } = require('../../constants'); const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY } = require('../../constants');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4 } = require('../../util'); const { forwardFetchResponse, getConfigValue, tryParse, uuidv4 } = require('../../util');
const { convertClaudePrompt, convertGooglePrompt, convertTextCompletionPrompt } = require('../prompt-converters'); const { convertClaudePrompt, convertGooglePrompt, convertTextCompletionPrompt } = require('../prompt-converters');
@ -160,8 +160,10 @@ async function sendMakerSuiteRequest(request, response) {
return response.status(400).send({ error: true }); return response.status(400).send({ error: true });
} }
const model = request.body.model; const model = String(request.body.model);
const stream = request.body.stream; const isGemini = model.includes('gemini');
const isText = model.includes('text');
const stream = Boolean(request.body.stream) && isGemini;
const generationConfig = { const generationConfig = {
stopSequences: request.body.stop, stopSequences: request.body.stop,
@ -172,11 +174,48 @@ async function sendMakerSuiteRequest(request, response) {
topK: request.body.top_k || undefined, topK: request.body.top_k || undefined,
}; };
const body = { function getGeminiBody() {
contents: convertGooglePrompt(request.body.messages, model), return {
safetySettings: MAKERSUITE_SAFETY, contents: convertGooglePrompt(request.body.messages, model),
generationConfig: generationConfig, safetySettings: GEMINI_SAFETY,
}; generationConfig: generationConfig,
};
}
function getBisonBody() {
const prompt = isText
? ({ text: convertTextCompletionPrompt(request.body.messages) })
: ({ messages: convertGooglePrompt(request.body.messages, model) });
/** @type {any} Shut the lint up */
const bisonBody = {
...generationConfig,
safetySettings: BISON_SAFETY,
candidate_count: 1, // lewgacy spelling
prompt: prompt,
};
if (!isText) {
delete bisonBody.stopSequences;
delete bisonBody.maxOutputTokens;
delete bisonBody.safetySettings;
if (Array.isArray(prompt.messages)) {
for (const msg of prompt.messages) {
msg.author = msg.role;
msg.content = msg.parts[0].text;
delete msg.parts;
delete msg.role;
}
}
}
delete bisonBody.candidateCount;
return bisonBody;
}
const body = isGemini ? getGeminiBody() : getBisonBody();
console.log('MakerSuite request:', body);
try { try {
const controller = new AbortController(); const controller = new AbortController();
@ -185,7 +224,12 @@ async function sendMakerSuiteRequest(request, response) {
controller.abort(); controller.abort();
}); });
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${model}:${stream ? 'streamGenerateContent' : 'generateContent'}?key=${apiKey}`, { const apiVersion = isGemini ? 'v1beta' : 'v1beta2';
const responseType = isGemini
? (stream ? 'streamGenerateContent' : 'generateContent')
: (isText ? 'generateText' : 'generateMessage');
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/${apiVersion}/models/${model}:${responseType}?key=${apiKey}`, {
body: JSON.stringify(body), body: JSON.stringify(body),
method: 'POST', method: 'POST',
headers: { headers: {
@ -251,8 +295,8 @@ async function sendMakerSuiteRequest(request, response) {
return response.send({ error: { message } }); return response.send({ error: { message } });
} }
const responseContent = candidates[0].content; const responseContent = candidates[0].content ?? candidates[0].output;
const responseText = responseContent.parts[0].text; const responseText = typeof responseContent === 'string' ? responseContent : responseContent.parts?.[0]?.text;
if (!responseText) { if (!responseText) {
let message = 'MakerSuite Candidate text empty'; let message = 'MakerSuite Candidate text empty';
console.log(message, generateResponseJson); console.log(message, generateResponseJson);

View File

@ -2,7 +2,7 @@ const { readSecret, SECRET_KEYS } = require('./secrets');
const fetch = require('node-fetch').default; const fetch = require('node-fetch').default;
const express = require('express'); const express = require('express');
const { jsonParser } = require('../express-common'); const { jsonParser } = require('../express-common');
const { MAKERSUITE_SAFETY } = require('../constants'); const { GEMINI_SAFETY } = require('../constants');
const router = express.Router(); const router = express.Router();
@ -22,7 +22,7 @@ router.post('/caption-image', jsonParser, async (request, response) => {
}, },
}], }],
}], }],
safetySettings: MAKERSUITE_SAFETY, safetySettings: GEMINI_SAFETY,
generationConfig: { maxOutputTokens: 1000 }, generationConfig: { maxOutputTokens: 1000 },
}; };