Split oobabooga/mancer sources. Add aphrodite support

This commit is contained in:
Cohee
2023-09-28 19:10:00 +03:00
parent 306cf51da4
commit bb47712696
6 changed files with 294 additions and 110 deletions

111
server.js
View File

@@ -148,9 +148,14 @@ let color = {
white: (mess) => color.byNum(mess, 37)
};
function get_mancer_headers() {
const api_key_mancer = readSecret(SECRET_KEYS.MANCER);
return api_key_mancer ? { "X-API-KEY": api_key_mancer } : {};
function getMancerHeaders() {
const apiKey = readSecret(SECRET_KEYS.MANCER);
return apiKey ? { "X-API-KEY": apiKey } : {};
}
function getAphroditeHeaders() {
const apiKey = readSecret(SECRET_KEYS.APHRODITE);
return apiKey ? { "X-API-KEY": apiKey } : {};
}
function getOverrideHeaders(urlHost) {
@@ -162,6 +167,26 @@ function getOverrideHeaders(urlHost) {
}
}
/**
* Sets additional headers for the request.
* @param {object} request Original request body
* @param {object} args New request arguments
* @param {string|null} server API server for new request
*/
function setAdditionalHeaders(request, args, server) {
let headers = {};
if (request.body.use_mancer) {
headers = getMancerHeaders();
} else if (request.body.use_aphrodite) {
headers = getAphroditeHeaders();
} else {
headers = server ? getOverrideHeaders((new URL(server))?.host) : '';
}
args.headers = Object.assign(args.headers, headers);
}
function humanizedISO8601DateTime(date) {
let baseDate = typeof date === 'number' ? new Date(date) : new Date();
let humanYear = baseDate.getFullYear();
@@ -451,6 +476,52 @@ app.post("/generate", jsonParser, async function (request, response_generate) {
return response_generate.send({ error: true });
});
/**
* @param {string} streamingUrlString Streaming URL
* @param {import('express').Request} request Express request
* @param {import('express').Response} response Express response
* @param {AbortController} controller Abort controller
* @returns
*/
async function sendAphroditeStreamingRequest(streamingUrlString, request, response, controller) {
request.body['stream'] = true;
const args = {
method: 'POST',
body: JSON.stringify(request.body),
headers: { "Content-Type": "application/json" },
signal: controller.signal,
};
setAdditionalHeaders(request, args, streamingUrlString);
try {
const generateResponse = await fetch(streamingUrlString + "/v1/generate", args);
// Pipe remote SSE stream to Express response
generateResponse.body.pipe(response);
request.socket.on('close', function () {
if (generateResponse.body instanceof Readable) generateResponse.body.destroy(); // Close the remote stream
response.end(); // End the Express response
});
generateResponse.body.on('end', function () {
console.log("Streaming request finished");
response.end();
});
} catch (error) {
let value = { error: true, status: error.status, response: error.statusText };
console.log("Aphrodite endpoint error:", error);
if (!response.headersSent) {
return response.send(value);
} else {
return response.end();
}
}
}
//************** Text generation web UI
app.post("/generate_textgenerationwebui", jsonParser, async function (request, response_generate) {
if (!request.body) return response_generate.sendStatus(400);
@@ -470,6 +541,10 @@ app.post("/generate_textgenerationwebui", jsonParser, async function (request, r
if (streamingUrlHeader === undefined) return response_generate.sendStatus(400);
const streamingUrlString = streamingUrlHeader.replace("localhost", "127.0.0.1");
if (request.body.use_aphrodite) {
return sendAphroditeStreamingRequest(streamingUrlString, request, response_generate, controller);
}
response_generate.writeHead(200, {
'Content-Type': 'text/plain;charset=utf-8',
'Transfer-Encoding': 'chunked',
@@ -482,9 +557,20 @@ app.post("/generate_textgenerationwebui", jsonParser, async function (request, r
websocket.on('open', async function () {
console.log('WebSocket opened');
let headers = {};
if (request.body.use_mancer) {
headers = getMancerHeaders();
} else if (request.body.use_aphrodite) {
headers = getAphroditeHeaders();
} else {
headers = getOverrideHeaders(streamingUrl?.host);
}
const combined_args = Object.assign(
{},
request.body.use_mancer ? get_mancer_headers() : getOverrideHeaders(streamingUrl?.host),
headers,
request.body
);
console.log(combined_args);
@@ -568,11 +654,7 @@ app.post("/generate_textgenerationwebui", jsonParser, async function (request, r
signal: controller.signal,
};
if (request.body.use_mancer) {
args.headers = Object.assign(args.headers, get_mancer_headers());
} else {
args.headers = Object.assign(args.headers, getOverrideHeaders((new URL(api_server))?.host));
}
setAdditionalHeaders(request, args, api_server);
try {
const data = await postAsync(api_server + "/v1/generate", args);
@@ -677,11 +759,7 @@ app.post("/getstatus", jsonParser, async function (request, response) {
headers: { "Content-Type": "application/json" }
};
if (main_api == 'textgenerationwebui' && request.body.use_mancer) {
args.headers = Object.assign(args.headers, get_mancer_headers());
} else {
args.headers = Object.assign(args.headers, getOverrideHeaders((new URL(api_server))?.host));
}
setAdditionalHeaders(request, args, api_server);
const url = api_server + "/v1/model";
let version = '';
@@ -3237,9 +3315,8 @@ app.post("/tokenize_via_api", jsonParser, async function (request, response) {
};
if (main_api == 'textgenerationwebui') {
if (request.body.use_mancer) {
args.headers = Object.assign(args.headers, get_mancer_headers());
}
setAdditionalHeaders(request, args, null);
const data = await postAsync(api_server + "/v1/token-count", args);
return response.send({ count: data['results'][0]['tokens'] });
}