From 055d6c4337807e9c9e0925b04257fe32bd3130a3 Mon Sep 17 00:00:00 2001 From: valadaptive Date: Thu, 7 Dec 2023 18:06:17 -0500 Subject: [PATCH] Properly forward status codes from streams --- server.js | 63 +++++++++++++--------------------------- src/endpoints/novelai.js | 15 ++-------- src/util.js | 23 +++++++++++++++ 3 files changed, 45 insertions(+), 56 deletions(-) diff --git a/server.js b/server.js index 5c25c2a0c..589bf481c 100644 --- a/server.js +++ b/server.js @@ -7,7 +7,6 @@ const http = require('http'); const https = require('https'); const path = require('path'); const util = require('util'); -const { Readable } = require('stream'); // cli/fs related library imports const open = require('open'); @@ -45,7 +44,20 @@ const basicAuthMiddleware = require('./src/middleware/basicAuthMiddleware'); const { jsonParser, urlencodedParser } = require('./src/express-common.js'); const contentManager = require('./src/endpoints/content-manager'); const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/endpoints/secrets'); -const { delay, getVersion, getConfigValue, color, uuidv4, tryParse, clientRelativePath, removeFileExtension, generateTimestamp, removeOldBackups, getImages } = require('./src/util'); +const { + delay, + getVersion, + getConfigValue, + color, + uuidv4, + tryParse, + clientRelativePath, + removeFileExtension, + generateTimestamp, + removeOldBackups, + getImages, + forwardFetchResponse, +} = require('./src/util'); const { ensureThumbnailCache } = require('./src/endpoints/thumbnails'); const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers'); const { convertClaudePrompt } = require('./src/chat-completion'); @@ -307,9 +319,7 @@ if (getConfigValue('enableCorsProxy', false) || cliArguments.corsProxy) { }); // Copy over relevant response params to the proxy response - res.statusCode = response.status; - res.statusMessage = response.statusText; - response.body.pipe(res); + forwardFetchResponse(response, res); } catch (error) { res.status(500).send('Error occurred while trying to proxy to: ' + url + ' ' + error); @@ -457,18 +467,9 @@ app.post('/generate', jsonParser, async function (request, response_generate) { const response = await fetch(url, { method: 'POST', timeout: 0, ...args }); if (request.body.streaming) { - request.socket.on('close', function () { - if (response.body instanceof Readable) response.body.destroy(); // Close the remote stream - response_generate.end(); // End the Express response - }); - - response.body.on('end', function () { - console.log('Streaming request finished'); - response_generate.end(); - }); - // Pipe remote SSE stream to Express response - return response.body.pipe(response_generate); + forwardFetchResponse(response, response_generate); + return; } else { if (!response.ok) { const errorText = await response.text(); @@ -666,17 +667,7 @@ app.post('/api/textgenerationwebui/generate', jsonParser, async function (reques if (request.body.stream) { const completionsStream = await fetch(url, args); // Pipe remote SSE stream to Express response - completionsStream.body.pipe(response_generate); - - request.socket.on('close', function () { - if (completionsStream.body instanceof Readable) completionsStream.body.destroy(); // Close the remote stream - response_generate.end(); // End the Express response - }); - - completionsStream.body.on('end', function () { - console.log('Streaming request finished'); - response_generate.end(); - }); + forwardFetchResponse(completionsStream, response_generate); } else { const completionsReply = await fetch(url, args); @@ -1427,17 +1418,7 @@ async function sendClaudeRequest(request, response) { if (request.body.stream) { // Pipe remote SSE stream to Express response - generateResponse.body.pipe(response); - - request.socket.on('close', function () { - if (generateResponse.body instanceof Readable) generateResponse.body.destroy(); // Close the remote stream - response.end(); // End the Express response - }); - - generateResponse.body.on('end', function () { - console.log('Streaming request finished'); - response.end(); - }); + forwardFetchResponse(generateResponse, response); } else { if (!generateResponse.ok) { console.log(`Claude API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`); @@ -1640,11 +1621,7 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_op if (fetchResponse.ok) { if (request.body.stream) { console.log('Streaming request in progress'); - fetchResponse.body.pipe(response_generate_openai); - fetchResponse.body.on('end', () => { - console.log('Streaming request finished'); - response_generate_openai.end(); - }); + forwardFetchResponse(fetchResponse, response_generate_openai); } else { let json = await fetchResponse.json(); response_generate_openai.send(json); diff --git a/src/endpoints/novelai.js b/src/endpoints/novelai.js index 2071c4c5b..89b460042 100644 --- a/src/endpoints/novelai.js +++ b/src/endpoints/novelai.js @@ -1,9 +1,8 @@ const fetch = require('node-fetch').default; const express = require('express'); const util = require('util'); -const { Readable } = require('stream'); const { readSecret, SECRET_KEYS } = require('./secrets'); -const { readAllChunks, extractFileFromZipBuffer } = require('../util'); +const { readAllChunks, extractFileFromZipBuffer, forwardFetchResponse } = require('../util'); const { jsonParser } = require('../express-common'); const API_NOVELAI = 'https://api.novelai.net'; @@ -188,17 +187,7 @@ router.post('/generate', jsonParser, async function (req, res) { if (req.body.streaming) { // Pipe remote SSE stream to Express response - response.body.pipe(res); - - req.socket.on('close', function () { - if (response.body instanceof Readable) response.body.destroy(); // Close the remote stream - res.end(); // End the Express response - }); - - response.body.on('end', function () { - console.log('Streaming request finished'); - res.end(); - }); + forwardFetchResponse(response, res); } else { if (!response.ok) { const text = await response.text(); diff --git a/src/util.js b/src/util.js index bc290c9ae..c6f344c71 100644 --- a/src/util.js +++ b/src/util.js @@ -6,6 +6,7 @@ const yauzl = require('yauzl'); const mime = require('mime-types'); const yaml = require('yaml'); const { default: simpleGit } = require('simple-git'); +const { Readable } = require('stream'); const { DIRECTORIES } = require('./constants'); @@ -346,6 +347,27 @@ function getImages(path) { .sort(Intl.Collator().compare); } +/** + * Pipe a fetch() response to an Express.js Response, including status code. + * @param {Response} from The Fetch API response to pipe from. + * @param {Express.Response} to The Express response to pipe to. + */ +function forwardFetchResponse(from, to) { + to.statusCode = from.status; + to.statusMessage = from.statusText; + from.body.pipe(to); + + to.socket.on('close', function () { + if (from.body instanceof Readable) from.body.destroy(); // Close the remote stream + to.end(); // End the Express response + }); + + from.body.on('end', function () { + console.log('Streaming request finished'); + to.end(); + }); +} + module.exports = { getConfig, getConfigValue, @@ -365,4 +387,5 @@ module.exports = { generateTimestamp, removeOldBackups, getImages, + forwardFetchResponse, };