Merge branch 'staging' into separate-altscale-endpoints

Cohee 2023-12-14 17:12:19 +02:00
commit 00687a9379
5 changed files with 185 additions and 160 deletions

View File

@@ -3665,12 +3665,13 @@ a {
 }

 .icon-svg {
-    fill: currentColor;
     /* Takes on the color of the surrounding text */
+    fill: currentColor;
     width: auto;
     height: 14px;
-    vertical-align: middle;
+    aspect-ratio: 1;
     /* To align with adjacent text */
+    place-self: center;
 }

 .paginationjs {

View File

@@ -217,7 +217,9 @@ if (!cliArguments.disableCsrf) {
 if (getConfigValue('enableCorsProxy', false) || cliArguments.corsProxy) {
     const bodyParser = require('body-parser');
-    app.use(bodyParser.json());
+    app.use(bodyParser.json({
+        limit: '200mb',
+    }));

     console.log('Enabling CORS proxy');

     app.use('/proxy/:url(*)', async (req, res) => {
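
Note: body-parser's json() middleware defaults to a 100kb limit and rejects larger payloads with 413 Payload Too Large, so the raised limit lets big JSON bodies (e.g. base64-encoded media) pass through the CORS proxy. A minimal sketch of the same pattern in isolation (the /echo route and port are illustrative, not part of this commit):

const express = require('express');
const bodyParser = require('body-parser');

const app = express();

// Accept JSON bodies up to 200 MB instead of body-parser's 100kb default.
app.use(bodyParser.json({ limit: '200mb' }));

// Illustrative route: echoes back whatever JSON body was accepted.
app.post('/echo', (req, res) => res.json(req.body));

app.listen(3000);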

View File

@@ -4,6 +4,7 @@ const fetch = require('node-fetch').default;
 const { jsonParser } = require('../../express-common');
 const { CHAT_COMPLETION_SOURCES, PALM_SAFETY } = require('../../constants');
 const { forwardFetchResponse, getConfigValue, tryParse, uuidv4 } = require('../../util');
+const { convertClaudePrompt, convertTextCompletionPrompt } = require('../prompt-converters');
 const { readSecret, SECRET_KEYS } = require('../secrets');
 const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
@@ -12,89 +13,15 @@ const API_OPENAI = 'https://api.openai.com/v1';
 const API_CLAUDE = 'https://api.anthropic.com/v1';

 /**
- * Convert a prompt from the ChatML objects to the format used by Claude.
- * @param {object[]} messages Array of messages
- * @param {boolean} addHumanPrefix Add Human prefix
- * @param {boolean} addAssistantPostfix Add Assistant postfix
- * @param {boolean} withSystemPrompt Build system prompt before "\n\nHuman: "
- * @returns {string} Prompt for Claude
- * @copyright Prompt Conversion script taken from RisuAI by kwaroran (GPLv3).
- */
-function convertClaudePrompt(messages, addHumanPrefix, addAssistantPostfix, withSystemPrompt) {
-    // Claude doesn't support message names, so we'll just add them to the message content.
-    for (const message of messages) {
-        if (message.name && message.role !== 'system') {
-            message.content = message.name + ': ' + message.content;
-            delete message.name;
-        }
-    }
-
-    let systemPrompt = '';
-    if (withSystemPrompt) {
-        let lastSystemIdx = -1;
-
-        for (let i = 0; i < messages.length - 1; i++) {
-            const message = messages[i];
-            if (message.role === 'system' && !message.name) {
-                systemPrompt += message.content + '\n\n';
-            } else {
-                lastSystemIdx = i - 1;
-                break;
-            }
-        }
-        if (lastSystemIdx >= 0) {
-            messages.splice(0, lastSystemIdx + 1);
-        }
-    }
-
-    let requestPrompt = messages.map((v) => {
-        let prefix = '';
-        switch (v.role) {
-            case 'assistant':
-                prefix = '\n\nAssistant: ';
-                break;
-            case 'user':
-                prefix = '\n\nHuman: ';
-                break;
-            case 'system':
-                // According to the Claude docs, H: and A: should be used for example conversations.
-                if (v.name === 'example_assistant') {
-                    prefix = '\n\nA: ';
-                } else if (v.name === 'example_user') {
-                    prefix = '\n\nH: ';
-                } else {
-                    prefix = '\n\n';
-                }
-                break;
-        }
-        return prefix + v.content;
-    }).join('');
-
-    if (addHumanPrefix) {
-        requestPrompt = '\n\nHuman: ' + requestPrompt;
-    }
-
-    if (addAssistantPostfix) {
-        requestPrompt = requestPrompt + '\n\nAssistant: ';
-    }
-
-    if (withSystemPrompt) {
-        requestPrompt = systemPrompt + requestPrompt;
-    }
-
-    return requestPrompt;
-}
-
-/**
- * @param {express.Request} request
- * @param {express.Response} response
+ * Sends a request to Claude API.
+ * @param {express.Request} request Express request
+ * @param {express.Response} response Express response
  */
 async function sendClaudeRequest(request, response) {
-    const api_url = new URL(request.body.reverse_proxy || API_CLAUDE).toString();
-    const api_key_claude = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.CLAUDE);
+    const apiUrl = new URL(request.body.reverse_proxy || API_CLAUDE).toString();
+    const apiKey = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.CLAUDE);

-    if (!api_key_claude) {
+    if (!apiKey) {
         console.log('Claude API key is missing.');
         return response.status(400).send({ error: true });
     }
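
Note: the first lines of sendClaudeRequest encode a precedence rule, namely that a user-configured reverse proxy (and its password) overrides both the stock endpoint and the stored secret. Sketched standalone, with the helper and secret names being illustrative:

// A reverse proxy supplied in the request body wins over the bundled
// Claude endpoint and the key from the secrets store.
function resolveClaudeTarget(body, readSecretFn) {
    const apiUrl = new URL(body.reverse_proxy || 'https://api.anthropic.com/v1').toString();
    const apiKey = body.reverse_proxy ? body.proxy_password : readSecretFn('CLAUDE');
    return { apiUrl, apiKey };
}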
@@ -121,7 +48,7 @@ async function sendClaudeRequest(request, response) {
         stop_sequences.push(...request.body.stop);
     }

-    const generateResponse = await fetch(api_url + '/complete', {
+    const generateResponse = await fetch(apiUrl + '/complete', {
         method: 'POST',
         signal: controller.signal,
         body: JSON.stringify({
@@ -137,7 +64,7 @@ async function sendClaudeRequest(request, response) {
         headers: {
             'Content-Type': 'application/json',
             'anthropic-version': '2023-06-01',
-            'x-api-key': api_key_claude,
+            'x-api-key': apiKey,
         },
         timeout: 0,
     });
@@ -167,37 +94,21 @@ async function sendClaudeRequest(request, response) {
     }
 }

-function convertChatMLPrompt(messages) {
-    if (typeof messages === 'string') {
-        return messages;
-    }
-
-    const messageStrings = [];
-    messages.forEach(m => {
-        if (m.role === 'system' && m.name === undefined) {
-            messageStrings.push('System: ' + m.content);
-        }
-        else if (m.role === 'system' && m.name !== undefined) {
-            messageStrings.push(m.name + ': ' + m.content);
-        }
-        else {
-            messageStrings.push(m.role + ': ' + m.content);
-        }
-    });
-    return messageStrings.join('\n') + '\nassistant:';
-}
-
+/**
+ * Sends a request to Scale Spellbook API.
+ * @param {import("express").Request} request Express request
+ * @param {import("express").Response} response Express response
+ */
 async function sendScaleRequest(request, response) {
-    const api_url = new URL(request.body.api_url_scale).toString();
-    const api_key_scale = readSecret(SECRET_KEYS.SCALE);
+    const apiUrl = new URL(request.body.api_url_scale).toString();
+    const apiKey = readSecret(SECRET_KEYS.SCALE);

-    if (!api_key_scale) {
+    if (!apiKey) {
         console.log('Scale API key is missing.');
         return response.status(400).send({ error: true });
     }

-    const requestPrompt = convertChatMLPrompt(request.body.messages);
+    const requestPrompt = convertTextCompletionPrompt(request.body.messages);
     console.log('Scale request:', requestPrompt);

     try {
@@ -207,12 +118,12 @@ async function sendScaleRequest(request, response) {
         controller.abort();
     });

-    const generateResponse = await fetch(api_url, {
+    const generateResponse = await fetch(apiUrl, {
         method: 'POST',
         body: JSON.stringify({ input: { input: requestPrompt } }),
         headers: {
             'Content-Type': 'application/json',
-            'Authorization': `Basic ${api_key_scale}`,
+            'Authorization': `Basic ${apiKey}`,
         },
         timeout: 0,
     });
@@ -236,8 +147,9 @@ async function sendScaleRequest(request, response) {
 }

 /**
- * @param {express.Request} request
- * @param {express.Response} response
+ * Sends a request to Google AI API.
+ * @param {express.Request} request Express request
+ * @param {express.Response} response Express response
  */
 async function sendPalmRequest(request, response) {
     const api_key_palm = readSecret(SECRET_KEYS.PALM);
@@ -285,15 +197,15 @@ async function sendPalmRequest(request, response) {
     }

     const generateResponseJson = await generateResponse.json();
-    const responseText = generateResponseJson?.candidates[0]?.output;
+    const responseText = generateResponseJson?.candidates?.[0]?.output;

     if (!responseText) {
         console.log('Palm API returned no response', generateResponseJson);
         let message = `Palm API returned no response: ${JSON.stringify(generateResponseJson)}`;

         // Check for filters
-        if (generateResponseJson?.filters[0]?.message) {
-            message = `Palm filter triggered: ${generateResponseJson.filters[0].message}`;
+        if (generateResponseJson?.filters?.[0]?.reason) {
+            message = `Palm filter triggered: ${generateResponseJson.filters[0].reason}`;
         }

         return response.send({ error: { message } });
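
Note: the ?.[0] changes here are behavioral, not cosmetic. Optional chaining only guards the property it is attached to, so the old expressions still threw once the payload arrived without a candidates or filters array. A quick illustration:

const payload = {}; // e.g. Palm answered without a candidates array

// Old form: payload?.candidates is undefined, and indexing undefined throws.
// payload?.candidates[0]?.output   // TypeError: Cannot read properties of undefined

// New form: ?.[0] guards the index access as well and yields undefined.
console.log(payload?.candidates?.[0]?.output); // undefined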
@@ -312,6 +224,11 @@ async function sendPalmRequest(request, response) {
     }
 }

+/**
+ * Sends a request to Google AI API.
+ * @param {express.Request} request Express request
+ * @param {express.Response} response Express response
+ */
 async function sendAI21Request(request, response) {
     if (!request.body) return response.sendStatus(400);
     const controller = new AbortController();
@@ -533,24 +450,24 @@ router.post('/bias', jsonParser, async function (request, response) {
 });

-router.post('/generate', jsonParser, function (request, response_generate_openai) {
-    if (!request.body) return response_generate_openai.status(400).send({ error: true });
+router.post('/generate', jsonParser, function (request, response) {
+    if (!request.body) return response.status(400).send({ error: true });

     switch (request.body.chat_completion_source) {
-        case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response_generate_openai);
-        case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response_generate_openai);
-        case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response_generate_openai);
-        case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest(request, response_generate_openai);
+        case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response);
+        case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response);
+        case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response);
+        case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest(request, response);
     }

-    let api_url;
-    let api_key_openai;
+    let apiUrl;
+    let apiKey;
     let headers;
     let bodyParams;

     if (request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.OPENROUTER) {
-        api_url = new URL(request.body.reverse_proxy || API_OPENAI).toString();
-        api_key_openai = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.OPENAI);
+        apiUrl = new URL(request.body.reverse_proxy || API_OPENAI).toString();
+        apiKey = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.OPENAI);
         headers = {};
         bodyParams = {};
@@ -558,8 +475,8 @@ router.post('/generate', jsonParser, function (request, response_generate_openai
             bodyParams['user'] = uuidv4();
         }
     } else {
-        api_url = 'https://openrouter.ai/api/v1';
-        api_key_openai = readSecret(SECRET_KEYS.OPENROUTER);
+        apiUrl = 'https://openrouter.ai/api/v1';
+        apiKey = readSecret(SECRET_KEYS.OPENROUTER);
         // OpenRouter needs to pass the referer: https://openrouter.ai/docs
         headers = { 'HTTP-Referer': request.headers.referer };
         bodyParams = { 'transforms': ['middle-out'] };
@@ -569,9 +486,9 @@ router.post('/generate', jsonParser, function (request, response_generate_openai
         }
     }

-    if (!api_key_openai && !request.body.reverse_proxy) {
+    if (!apiKey && !request.body.reverse_proxy) {
         console.log('OpenAI API key is missing.');
-        return response_generate_openai.status(400).send({ error: true });
+        return response.status(400).send({ error: true });
     }

     // Add custom stop sequences
@@ -580,10 +497,10 @@ router.post('/generate', jsonParser, function (request, response_generate_openai
     }

     const isTextCompletion = Boolean(request.body.model && TEXT_COMPLETION_MODELS.includes(request.body.model)) || typeof request.body.messages === 'string';
-    const textPrompt = isTextCompletion ? convertChatMLPrompt(request.body.messages) : '';
+    const textPrompt = isTextCompletion ? convertTextCompletionPrompt(request.body.messages) : '';
     const endpointUrl = isTextCompletion && request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.OPENROUTER ?
-        `${api_url}/completions` :
-        `${api_url}/chat/completions`;
+        `${apiUrl}/completions` :
+        `${apiUrl}/chat/completions`;

     const controller = new AbortController();
     request.socket.removeAllListeners('close');
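
Note: the endpoint selection above reduces to one rule: text-completion models (or a raw string prompt) go to /completions, while everything else, including all OpenRouter traffic, goes to /chat/completions. As a standalone sketch (the function name is illustrative):

// Text-completion requests use /completions unless routed through OpenRouter.
function pickEndpoint(apiUrl, isTextCompletion, isOpenRouter) {
    return isTextCompletion && !isOpenRouter
        ? `${apiUrl}/completions`
        : `${apiUrl}/chat/completions`;
}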
@@ -596,7 +513,7 @@ router.post('/generate', jsonParser, function (request, response_generate_openai
         method: 'post',
         headers: {
             'Content-Type': 'application/json',
-            'Authorization': 'Bearer ' + api_key_openai,
+            'Authorization': 'Bearer ' + apiKey,
             ...headers,
         },
         body: JSON.stringify({
@@ -621,52 +538,55 @@ router.post('/generate', jsonParser, function (request, response_generate_openai
     console.log(JSON.parse(String(config.body)));

-    makeRequest(config, response_generate_openai, request);
+    makeRequest(config, response, request);

     /**
-     *
-     * @param {*} config
-     * @param {express.Response} response_generate_openai
-     * @param {express.Request} request
-     * @param {Number} retries
-     * @param {Number} timeout
+     * Makes a fetch request to the OpenAI API endpoint.
+     * @param {import('node-fetch').RequestInit} config Fetch config
+     * @param {express.Response} response Express response
+     * @param {express.Request} request Express request
+     * @param {Number} retries Number of retries left
+     * @param {Number} timeout Request timeout in ms
      */
-    async function makeRequest(config, response_generate_openai, request, retries = 5, timeout = 5000) {
+    async function makeRequest(config, response, request, retries = 5, timeout = 5000) {
         try {
             const fetchResponse = await fetch(endpointUrl, config);

             if (request.body.stream) {
                 console.log('Streaming request in progress');
-                forwardFetchResponse(fetchResponse, response_generate_openai);
+                forwardFetchResponse(fetchResponse, response);
                 return;
             }

             if (fetchResponse.ok) {
                 let json = await fetchResponse.json();
-                response_generate_openai.send(json);
+                response.send(json);
                 console.log(json);
                 console.log(json?.choices[0]?.message);
             } else if (fetchResponse.status === 429 && retries > 0) {
                 console.log(`Out of quota, retrying in ${Math.round(timeout / 1000)}s`);
                 setTimeout(() => {
                     timeout *= 2;
-                    makeRequest(config, response_generate_openai, request, retries - 1, timeout);
+                    makeRequest(config, response, request, retries - 1, timeout);
                 }, timeout);
             } else {
                 await handleErrorResponse(fetchResponse);
             }
         } catch (error) {
             console.log('Generation failed', error);
-            if (!response_generate_openai.headersSent) {
-                response_generate_openai.send({ error: true });
+            if (!response.headersSent) {
+                response.send({ error: true });
             } else {
-                response_generate_openai.end();
+                response.end();
             }
         }
     }

-    async function handleErrorResponse(response) {
-        const responseText = await response.text();
+    /**
+     * @param {import("node-fetch").Response} errorResponse
+     */
+    async function handleErrorResponse(errorResponse) {
+        const responseText = await errorResponse.text();
         const errorData = tryParse(responseText);

         const statusMessages = {
@@ -680,21 +600,20 @@ router.post('/generate', jsonParser, function (request, response_generate_openai
             502: 'Bad gateway',
         };

-        const message = errorData?.error?.message || statusMessages[response.status] || 'Unknown error occurred';
-        const quota_error = response.status === 429 && errorData?.error?.type === 'insufficient_quota';
+        const message = errorData?.error?.message || statusMessages[errorResponse.status] || 'Unknown error occurred';
+        const quota_error = errorResponse.status === 429 && errorData?.error?.type === 'insufficient_quota';
         console.log(message);

-        if (!response_generate_openai.headersSent) {
-            response_generate_openai.send({ error: { message }, quota_error: quota_error });
-        } else if (!response_generate_openai.writableEnded) {
-            response_generate_openai.write(response);
+        if (!response.headersSent) {
+            response.send({ error: { message }, quota_error: quota_error });
+        } else if (!response.writableEnded) {
+            response.write(errorResponse);
         } else {
-            response_generate_openai.end();
+            response.end();
         }
     }
 });

 module.exports = {
     router,
-    convertClaudePrompt,
 };
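
Note: makeRequest's 429 branch retries with exponential backoff; the wait starts at the 5000 ms default and doubles before each recursive call. A small loop reproduces the schedule:

// Reproduces the wait schedule of makeRequest's 429 retry branch.
let timeout = 5000;
for (let retries = 5; retries > 0; retries--) {
    console.log(`retrying in ${Math.round(timeout / 1000)}s`);
    timeout *= 2;
}
// Prints 5s, 10s, 20s, 40s, 80s: up to six attempts and ~155s of waiting in total.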

View File

@@ -0,0 +1,103 @@
+/**
+ * Convert a prompt from the ChatML objects to the format used by Claude.
+ * @param {object[]} messages Array of messages
+ * @param {boolean} addHumanPrefix Add Human prefix
+ * @param {boolean} addAssistantPostfix Add Assistant postfix
+ * @param {boolean} withSystemPrompt Build system prompt before "\n\nHuman: "
+ * @returns {string} Prompt for Claude
+ * @copyright Prompt Conversion script taken from RisuAI by kwaroran (GPLv3).
+ */
+function convertClaudePrompt(messages, addHumanPrefix, addAssistantPostfix, withSystemPrompt) {
+    // Claude doesn't support message names, so we'll just add them to the message content.
+    for (const message of messages) {
+        if (message.name && message.role !== 'system') {
+            message.content = message.name + ': ' + message.content;
+            delete message.name;
+        }
+    }
+
+    let systemPrompt = '';
+    if (withSystemPrompt) {
+        let lastSystemIdx = -1;
+
+        for (let i = 0; i < messages.length - 1; i++) {
+            const message = messages[i];
+            if (message.role === 'system' && !message.name) {
+                systemPrompt += message.content + '\n\n';
+            } else {
+                lastSystemIdx = i - 1;
+                break;
+            }
+        }
+        if (lastSystemIdx >= 0) {
+            messages.splice(0, lastSystemIdx + 1);
+        }
+    }
+
+    let requestPrompt = messages.map((v) => {
+        let prefix = '';
+        switch (v.role) {
+            case 'assistant':
+                prefix = '\n\nAssistant: ';
+                break;
+            case 'user':
+                prefix = '\n\nHuman: ';
+                break;
+            case 'system':
+                // According to the Claude docs, H: and A: should be used for example conversations.
+                if (v.name === 'example_assistant') {
+                    prefix = '\n\nA: ';
+                } else if (v.name === 'example_user') {
+                    prefix = '\n\nH: ';
+                } else {
+                    prefix = '\n\n';
+                }
+                break;
+        }
+        return prefix + v.content;
+    }).join('');
+
+    if (addHumanPrefix) {
+        requestPrompt = '\n\nHuman: ' + requestPrompt;
+    }
+
+    if (addAssistantPostfix) {
+        requestPrompt = requestPrompt + '\n\nAssistant: ';
+    }
+
+    if (withSystemPrompt) {
+        requestPrompt = systemPrompt + requestPrompt;
+    }
+
+    return requestPrompt;
+}
+
+/**
+ * Convert a prompt from the ChatML objects to the format used by Text Completion API.
+ * @param {object[]} messages Array of messages
+ * @returns {string} Prompt for Text Completion API
+ */
+function convertTextCompletionPrompt(messages) {
+    if (typeof messages === 'string') {
+        return messages;
+    }
+
+    const messageStrings = [];
+    messages.forEach(m => {
+        if (m.role === 'system' && m.name === undefined) {
+            messageStrings.push('System: ' + m.content);
+        }
+        else if (m.role === 'system' && m.name !== undefined) {
+            messageStrings.push(m.name + ': ' + m.content);
+        }
+        else {
+            messageStrings.push(m.role + ': ' + m.content);
+        }
+    });
+    return messageStrings.join('\n') + '\nassistant:';
+}
+
+module.exports = {
+    convertClaudePrompt,
+    convertTextCompletionPrompt,
+};
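
A worked example of the two converters, with an illustrative message list (note that convertClaudePrompt mutates the array it is given):

const { convertClaudePrompt, convertTextCompletionPrompt } = require('./prompt-converters');

// withSystemPrompt hoists the leading system message; addAssistantPostfix
// appends the "Assistant: " cue for the model's next turn.
convertClaudePrompt([
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Hello!' },
    { role: 'assistant', content: 'Hi! How can I help?' },
], false, true, true);
// => 'You are a helpful assistant.\n\n\n\nHuman: Hello!\n\nAssistant: Hi! How can I help?\n\nAssistant: '

// The text-completion variant emits a flat "role: content" transcript
// ending with an "assistant:" cue.
convertTextCompletionPrompt([
    { role: 'system', content: 'Be concise.' },
    { role: 'user', content: 'Hello!' },
]);
// => 'System: Be concise.\nuser: Hello!\nassistant:'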

View File

@@ -4,7 +4,7 @@ const express = require('express');
 const { SentencePieceProcessor } = require('@agnai/sentencepiece-js');
 const tiktoken = require('@dqbd/tiktoken');
 const { Tokenizer } = require('@agnai/web-tokenizers');
-const { convertClaudePrompt } = require('./textgen/chat-completions');
+const { convertClaudePrompt } = require('./prompt-converters');
 const { readSecret, SECRET_KEYS } = require('./secrets');
 const { TEXTGEN_TYPES } = require('../constants');
 const { jsonParser } = require('../express-common');