From 60880cfd4d5add7257fc8fa38ed36fe965a71d58 Mon Sep 17 00:00:00 2001 From: based Date: Fri, 15 Dec 2023 01:39:12 +1000 Subject: [PATCH] merge --- public/scripts/openai.js | 6 +- src/endpoints/backends/chat-completions.js | 128 +++++++++++++++------ 2 files changed, 93 insertions(+), 41 deletions(-) diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 08c1cd386..b5b505e3f 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -31,7 +31,7 @@ import { system_message_types, this_chid, } from '../script.js'; -import {groups, selected_group} from './group-chats.js'; +import { groups, selected_group } from './group-chats.js'; import { chatCompletionDefaultPrompts, @@ -41,8 +41,8 @@ import { promptManagerDefaultPromptOrders, } from './PromptManager.js'; -import {getCustomStoppingStrings, persona_description_positions, power_user} from './power-user.js'; -import {SECRET_KEYS, secret_state, writeSecret} from './secrets.js'; +import { getCustomStoppingStrings, persona_description_positions, power_user } from './power-user.js'; +import { SECRET_KEYS, secret_state, writeSecret } from './secrets.js'; import EventSourceStream from './sse-stream.js'; import { diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index af463bd21..16b87ecd6 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -2,9 +2,9 @@ const express = require('express'); const fetch = require('node-fetch').default; const { jsonParser } = require('../../express-common'); -const { CHAT_COMPLETION_SOURCES, PALM_SAFETY } = require('../../constants'); +const { CHAT_COMPLETION_SOURCES, MAKERSUITE_SAFETY } = require('../../constants'); const { forwardFetchResponse, getConfigValue, tryParse, uuidv4 } = require('../../util'); -const { convertClaudePrompt, convertTextCompletionPrompt } = require('../prompt-converters'); +const { convertClaudePrompt, convertGooglePrompt, convertTextCompletionPrompt } = require('../prompt-converters'); const { readSecret, SECRET_KEYS } = require('../secrets'); const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers'); @@ -151,28 +151,35 @@ async function sendScaleRequest(request, response) { * @param {express.Request} request Express request * @param {express.Response} response Express response */ -async function sendPalmRequest(request, response) { - const api_key_palm = readSecret(SECRET_KEYS.PALM); +/** + * @param {express.Request} request + * @param {express.Response} response + */ +async function sendMakerSuiteRequest(request, response) { + const api_key_makersuite = readSecret(SECRET_KEYS.MAKERSUITE); - if (!api_key_palm) { - console.log('Palm API key is missing.'); + if (!api_key_makersuite) { + console.log('MakerSuite API key is missing.'); return response.status(400).send({ error: true }); } - const body = { - prompt: { - text: request.body.messages, - }, + const google_model = request.body.model; + const should_stream = request.body.stream; + + const generationConfig = { stopSequences: request.body.stop, - safetySettings: PALM_SAFETY, + candidateCount: 1, + maxOutputTokens: request.body.max_tokens, temperature: request.body.temperature, topP: request.body.top_p, topK: request.body.top_k || undefined, - maxOutputTokens: request.body.max_tokens, - candidate_count: 1, }; - console.log('Palm request:', body); + const body = { + contents: convertGooglePrompt(request.body.messages, google_model), + safetySettings: MAKERSUITE_SAFETY, + generationConfig: generationConfig, + }; try { const controller = new AbortController(); @@ -181,7 +188,7 @@ async function sendPalmRequest(request, response) { controller.abort(); }); - const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta2/models/text-bison-001:generateText?key=${api_key_palm}`, { + const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${google_model}:${should_stream ? 'streamGenerateContent' : 'generateContent'}?key=${api_key_makersuite}`, { body: JSON.stringify(body), method: 'POST', headers: { @@ -190,34 +197,79 @@ async function sendPalmRequest(request, response) { signal: controller.signal, timeout: 0, }); + // have to do this because of their busted ass streaming endpoint + if (should_stream) { + try { + let partialData = ''; + generateResponse.body.on('data', (data) => { + const chunk = data.toString(); + if (chunk.startsWith(',') || chunk.endsWith(',') || chunk.startsWith('[') || chunk.endsWith(']')) { + partialData = chunk.slice(1); + } else { + partialData += chunk; + } + while (true) { + let json; + try { + json = JSON.parse(partialData); + } catch (e) { + break; + } + response.write(JSON.stringify(json)); + partialData = ''; + } + }); - if (!generateResponse.ok) { - console.log(`Palm API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`); - return response.status(generateResponse.status).send({ error: true }); - } + request.socket.on('close', function () { + generateResponse.body.destroy(); + response.end(); + }); - const generateResponseJson = await generateResponse.json(); - const responseText = generateResponseJson?.candidates?.[0]?.output; + generateResponse.body.on('end', () => { + console.log('Streaming request finished'); + response.end(); + }); - if (!responseText) { - console.log('Palm API returned no response', generateResponseJson); - let message = `Palm API returned no response: ${JSON.stringify(generateResponseJson)}`; - - // Check for filters - if (generateResponseJson?.filters?.[0]?.reason) { - message = `Palm filter triggered: ${generateResponseJson.filters[0].reason}`; + } catch (error) { + console.log('Error forwarding streaming response:', error); + if (!response.headersSent) { + return response.status(500).send({ error: true }); + } + } + } else { + if (!generateResponse.ok) { + console.log(`MakerSuite API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`); + return response.status(generateResponse.status).send({ error: true }); } - return response.send({ error: { message } }); + const generateResponseJson = await generateResponse.json(); + + const candidates = generateResponseJson?.candidates; + if (!candidates || candidates.length === 0) { + let message = 'MakerSuite API returned no candidate'; + console.log(message, generateResponseJson); + if (generateResponseJson?.promptFeedback?.blockReason) { + message += `\nPrompt was blocked due to : ${generateResponseJson.promptFeedback.blockReason}`; + } + return response.send({ error: { message } }); + } + + const responseContent = candidates[0].content; + const responseText = responseContent.parts[0].text; + if (!responseText) { + let message = 'MakerSuite Candidate text empty'; + console.log(message, generateResponseJson); + return response.send({ error: { message } }); + } + + console.log('MakerSuite response:', responseText); + + // Wrap it back to OAI format + const reply = { choices: [{ 'message': { 'content': responseText } }] }; + return response.send(reply); } - - console.log('Palm response:', responseText); - - // Wrap it back to OAI format - const reply = { choices: [{ 'message': { 'content': responseText } }] }; - return response.send(reply); } catch (error) { - console.log('Error communicating with Palm API: ', error); + console.log('Error communicating with MakerSuite API: ', error); if (!response.headersSent) { return response.status(500).send({ error: true }); } @@ -225,7 +277,7 @@ async function sendPalmRequest(request, response) { } /** - * Sends a request to Google AI API. + * Sends a request to AI21 API. * @param {express.Request} request Express request * @param {express.Response} response Express response */ @@ -457,7 +509,7 @@ router.post('/generate', jsonParser, function (request, response) { case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response); case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response); case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response); - case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest(request, response); + case CHAT_COMPLETION_SOURCES.MAKERSUITE: return sendMakerSuiteRequest(request, response); } let apiUrl;