commit 60880cfd4d
parent 698850b514
Author: based
Date:   2023-12-15 01:39:12 +10:00

2 changed files with 93 additions and 41 deletions

View File

@@ -31,7 +31,7 @@ import {
     system_message_types,
     this_chid,
 } from '../script.js';
-import {groups, selected_group} from './group-chats.js';
+import { groups, selected_group } from './group-chats.js';
 import {
     chatCompletionDefaultPrompts,
@@ -41,8 +41,8 @@ import {
     promptManagerDefaultPromptOrders,
 } from './PromptManager.js';
-import {getCustomStoppingStrings, persona_description_positions, power_user} from './power-user.js';
-import {SECRET_KEYS, secret_state, writeSecret} from './secrets.js';
+import { getCustomStoppingStrings, persona_description_positions, power_user } from './power-user.js';
+import { SECRET_KEYS, secret_state, writeSecret } from './secrets.js';
 import EventSourceStream from './sse-stream.js';
 import {

View File

@@ -2,9 +2,9 @@ const express = require('express');
 const fetch = require('node-fetch').default;
 const { jsonParser } = require('../../express-common');
-const { CHAT_COMPLETION_SOURCES, PALM_SAFETY } = require('../../constants');
+const { CHAT_COMPLETION_SOURCES, MAKERSUITE_SAFETY } = require('../../constants');
 const { forwardFetchResponse, getConfigValue, tryParse, uuidv4 } = require('../../util');
-const { convertClaudePrompt, convertTextCompletionPrompt } = require('../prompt-converters');
+const { convertClaudePrompt, convertGooglePrompt, convertTextCompletionPrompt } = require('../prompt-converters');
 const { readSecret, SECRET_KEYS } = require('../secrets');
 const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
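The new convertGooglePrompt import is the piece that maps the OpenAI-style messages array onto the contents format expected by the Google v1beta API. The converter itself lives in ../prompt-converters and is not part of this diff, so the following is only a sketch of what such a conversion typically looks like (roles 'user'/'model', one parts array per turn), not the project's implementation:

// Hypothetical sketch of an OAI-messages -> Gemini-contents converter.
// NOT the code from ../prompt-converters; shown only to clarify the data shape.
function convertGooglePromptSketch(messages) {
    const contents = [];
    for (const message of messages) {
        // The v1beta API only knows 'user' and 'model' roles, so 'system' is folded into 'user'.
        const role = message.role === 'assistant' ? 'model' : 'user';
        const text = typeof message.content === 'string' ? message.content : '';

        // Consecutive turns with the same role are merged, since the API rejects them.
        const last = contents[contents.length - 1];
        if (last && last.role === role) {
            last.parts[0].text += '\n\n' + text;
        } else {
            contents.push({ role: role, parts: [{ text: text }] });
        }
    }
    return contents;
}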
@@ -151,28 +151,35 @@ async function sendScaleRequest(request, response) {
  * @param {express.Request} request Express request
  * @param {express.Response} response Express response
  */
-async function sendPalmRequest(request, response) {
-    const api_key_palm = readSecret(SECRET_KEYS.PALM);
+/**
+ * @param {express.Request} request
+ * @param {express.Response} response
+ */
+async function sendMakerSuiteRequest(request, response) {
+    const api_key_makersuite = readSecret(SECRET_KEYS.MAKERSUITE);
 
-    if (!api_key_palm) {
-        console.log('Palm API key is missing.');
+    if (!api_key_makersuite) {
+        console.log('MakerSuite API key is missing.');
         return response.status(400).send({ error: true });
     }
 
-    const body = {
-        prompt: {
-            text: request.body.messages,
-        },
+    const google_model = request.body.model;
+    const should_stream = request.body.stream;
+
+    const generationConfig = {
         stopSequences: request.body.stop,
-        safetySettings: PALM_SAFETY,
+        candidateCount: 1,
+        maxOutputTokens: request.body.max_tokens,
         temperature: request.body.temperature,
         topP: request.body.top_p,
         topK: request.body.top_k || undefined,
-        maxOutputTokens: request.body.max_tokens,
-        candidate_count: 1,
     };
 
-    console.log('Palm request:', body);
+    const body = {
+        contents: convertGooglePrompt(request.body.messages, google_model),
+        safetySettings: MAKERSUITE_SAFETY,
+        generationConfig: generationConfig,
+    };
 
     try {
         const controller = new AbortController();
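For reference, the body assembled above serializes to roughly the following JSON. The values are illustrative only; the exact contents array depends on convertGooglePrompt and the safety entries come from the MAKERSUITE_SAFETY constant, neither of which is shown in this diff:

// Illustrative payload only; all values are made up.
const exampleBody = {
    contents: [
        { role: 'user', parts: [{ text: 'Write a haiku about rain.' }] },
    ],
    safetySettings: [
        // Example entry in the { category, threshold } shape used by the v1beta API;
        // the real list comes from MAKERSUITE_SAFETY in ../../constants.
        { category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_NONE' },
    ],
    generationConfig: {
        stopSequences: ['\nUser:'],
        candidateCount: 1,
        maxOutputTokens: 300,
        temperature: 0.9,
        topP: 0.95,
        topK: 40,
    },
};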
@@ -181,7 +188,7 @@ async function sendPalmRequest(request, response) {
             controller.abort();
         });
 
-        const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta2/models/text-bison-001:generateText?key=${api_key_palm}`, {
+        const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${google_model}:${should_stream ? 'streamGenerateContent' : 'generateContent'}?key=${api_key_makersuite}`, {
             body: JSON.stringify(body),
             method: 'POST',
             headers: {
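The only change in this hunk is the endpoint: instead of the fixed PaLM text-bison-001:generateText URL, the URL is now built from the requested model and switches between the streaming and non-streaming methods. A minimal standalone sketch of the same construction, with placeholder model and key values:

// Builds the same style of URL as the changed line above; arguments are placeholders.
function makeGenerateUrl(model, apiKey, stream) {
    const method = stream ? 'streamGenerateContent' : 'generateContent';
    return `https://generativelanguage.googleapis.com/v1beta/models/${model}:${method}?key=${apiKey}`;
}

// e.g. makeGenerateUrl('gemini-pro', process.env.MAKERSUITE_KEY, false)
// -> https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=...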
@@ -190,34 +197,79 @@ async function sendPalmRequest(request, response) {
             signal: controller.signal,
             timeout: 0,
         });
-
-        if (!generateResponse.ok) {
-            console.log(`Palm API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
-            return response.status(generateResponse.status).send({ error: true });
-        }
-
-        const generateResponseJson = await generateResponse.json();
-        const responseText = generateResponseJson?.candidates?.[0]?.output;
-
-        if (!responseText) {
-            console.log('Palm API returned no response', generateResponseJson);
-            let message = `Palm API returned no response: ${JSON.stringify(generateResponseJson)}`;
-
-            // Check for filters
-            if (generateResponseJson?.filters?.[0]?.reason) {
-                message = `Palm filter triggered: ${generateResponseJson.filters[0].reason}`;
-            }
-
-            return response.send({ error: { message } });
-        }
-
-        console.log('Palm response:', responseText);
-
-        // Wrap it back to OAI format
-        const reply = { choices: [{ 'message': { 'content': responseText } }] };
-        return response.send(reply);
+        // have to do this because of their busted ass streaming endpoint
+        if (should_stream) {
+            try {
+                let partialData = '';
+                generateResponse.body.on('data', (data) => {
+                    const chunk = data.toString();
+                    if (chunk.startsWith(',') || chunk.endsWith(',') || chunk.startsWith('[') || chunk.endsWith(']')) {
+                        partialData = chunk.slice(1);
+                    } else {
+                        partialData += chunk;
+                    }
+                    while (true) {
+                        let json;
+                        try {
+                            json = JSON.parse(partialData);
+                        } catch (e) {
+                            break;
+                        }
+                        response.write(JSON.stringify(json));
+                        partialData = '';
+                    }
+                });
+
+                request.socket.on('close', function () {
+                    generateResponse.body.destroy();
+                    response.end();
+                });
+
+                generateResponse.body.on('end', () => {
+                    console.log('Streaming request finished');
+                    response.end();
+                });
+            } catch (error) {
+                console.log('Error forwarding streaming response:', error);
+                if (!response.headersSent) {
+                    return response.status(500).send({ error: true });
+                }
+            }
+        } else {
+            if (!generateResponse.ok) {
+                console.log(`MakerSuite API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
+                return response.status(generateResponse.status).send({ error: true });
+            }
+
+            const generateResponseJson = await generateResponse.json();
+
+            const candidates = generateResponseJson?.candidates;
+            if (!candidates || candidates.length === 0) {
+                let message = 'MakerSuite API returned no candidate';
+                console.log(message, generateResponseJson);
+                if (generateResponseJson?.promptFeedback?.blockReason) {
+                    message += `\nPrompt was blocked due to : ${generateResponseJson.promptFeedback.blockReason}`;
+                }
+                return response.send({ error: { message } });
+            }
+
+            const responseContent = candidates[0].content;
+            const responseText = responseContent.parts[0].text;
+            if (!responseText) {
+                let message = 'MakerSuite Candidate text empty';
+                console.log(message, generateResponseJson);
+                return response.send({ error: { message } });
+            }
+
+            console.log('MakerSuite response:', responseText);
+            // Wrap it back to OAI format
+            const reply = { choices: [{ 'message': { 'content': responseText } }] };
+            return response.send(reply);
+        }
     } catch (error) {
-        console.log('Error communicating with Palm API: ', error);
+        console.log('Error communicating with MakerSuite API: ', error);
         if (!response.headersSent) {
            return response.status(500).send({ error: true });
         }
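The streaming branch above exists because streamGenerateContent (when not using server-sent events) emits one large JSON array incrementally, so raw chunks arrive as fragments like `[{...}`, `,{...}` and `]`. The handler drops the leading bracket or comma, accumulates the remainder until it parses as a complete object, and forwards each parsed object to the client. The same logic, pulled out into a standalone helper purely for clarity; this is a sketch mirroring the hunk, not code from the commit:

// Accumulates raw chunks from streamGenerateContent and calls onObject for every
// complete JSON object parsed so far. Mirrors the chunk handling in the diff above.
function createStreamAccumulator(onObject) {
    let partialData = '';
    return function onChunk(data) {
        const chunk = data.toString();
        // Chunks that look like array delimiters reset the buffer and lose their first character;
        // everything else is appended until JSON.parse succeeds.
        if (chunk.startsWith(',') || chunk.endsWith(',') || chunk.startsWith('[') || chunk.endsWith(']')) {
            partialData = chunk.slice(1);
        } else {
            partialData += chunk;
        }
        while (true) {
            let json;
            try {
                json = JSON.parse(partialData);
            } catch (e) {
                break; // not a complete object yet; wait for more data
            }
            onObject(json);
            partialData = '';
        }
    };
}

// Usage in the route would look like:
// generateResponse.body.on('data', createStreamAccumulator((obj) => response.write(JSON.stringify(obj))));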
@@ -225,7 +277,7 @@ async function sendPalmRequest(request, response) {
 }
 
 /**
- * Sends a request to Google AI API.
+ * Sends a request to AI21 API.
  * @param {express.Request} request Express request
  * @param {express.Response} response Express response
 */
@@ -457,7 +509,7 @@ router.post('/generate', jsonParser, function (request, response) {
         case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response);
         case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response);
         case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response);
-        case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest(request, response);
+        case CHAT_COMPLETION_SOURCES.MAKERSUITE: return sendMakerSuiteRequest(request, response);
     }
 
     let apiUrl;
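Because the non-streaming branch wraps the candidate text back into an OpenAI-style choices array, callers of this /generate route read the MakerSuite reply the same way as the other sources. A minimal consumer sketch; the router's mount path is not shown in this diff, so '/generate' below is only a stand-in:

// Sketch of how a caller reads the non-streaming reply shape built above.
async function readReply(payload) {
    const res = await fetch('/generate', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify(payload),
    });
    const data = await res.json();
    // Same OAI-style shape returned by the route: { choices: [{ message: { content } }] }
    return data?.choices?.[0]?.message?.content ?? '';
}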