Properly forward status codes from streams

This commit is contained in:
valadaptive 2023-12-07 18:06:17 -05:00
parent 5569a63595
commit 055d6c4337
3 changed files with 45 additions and 56 deletions

View File

@ -7,7 +7,6 @@ const http = require('http');
const https = require('https'); const https = require('https');
const path = require('path'); const path = require('path');
const util = require('util'); const util = require('util');
const { Readable } = require('stream');
// cli/fs related library imports // cli/fs related library imports
const open = require('open'); const open = require('open');
@ -45,7 +44,20 @@ const basicAuthMiddleware = require('./src/middleware/basicAuthMiddleware');
const { jsonParser, urlencodedParser } = require('./src/express-common.js'); const { jsonParser, urlencodedParser } = require('./src/express-common.js');
const contentManager = require('./src/endpoints/content-manager'); const contentManager = require('./src/endpoints/content-manager');
const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/endpoints/secrets'); const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/endpoints/secrets');
const { delay, getVersion, getConfigValue, color, uuidv4, tryParse, clientRelativePath, removeFileExtension, generateTimestamp, removeOldBackups, getImages } = require('./src/util'); const {
delay,
getVersion,
getConfigValue,
color,
uuidv4,
tryParse,
clientRelativePath,
removeFileExtension,
generateTimestamp,
removeOldBackups,
getImages,
forwardFetchResponse,
} = require('./src/util');
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails'); const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers'); const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
const { convertClaudePrompt } = require('./src/chat-completion'); const { convertClaudePrompt } = require('./src/chat-completion');
@ -307,9 +319,7 @@ if (getConfigValue('enableCorsProxy', false) || cliArguments.corsProxy) {
}); });
// Copy over relevant response params to the proxy response // Copy over relevant response params to the proxy response
res.statusCode = response.status; forwardFetchResponse(response, res);
res.statusMessage = response.statusText;
response.body.pipe(res);
} catch (error) { } catch (error) {
res.status(500).send('Error occurred while trying to proxy to: ' + url + ' ' + error); res.status(500).send('Error occurred while trying to proxy to: ' + url + ' ' + error);
@ -457,18 +467,9 @@ app.post('/generate', jsonParser, async function (request, response_generate) {
const response = await fetch(url, { method: 'POST', timeout: 0, ...args }); const response = await fetch(url, { method: 'POST', timeout: 0, ...args });
if (request.body.streaming) { if (request.body.streaming) {
request.socket.on('close', function () {
if (response.body instanceof Readable) response.body.destroy(); // Close the remote stream
response_generate.end(); // End the Express response
});
response.body.on('end', function () {
console.log('Streaming request finished');
response_generate.end();
});
// Pipe remote SSE stream to Express response // Pipe remote SSE stream to Express response
return response.body.pipe(response_generate); forwardFetchResponse(response, response_generate);
return;
} else { } else {
if (!response.ok) { if (!response.ok) {
const errorText = await response.text(); const errorText = await response.text();
@ -666,17 +667,7 @@ app.post('/api/textgenerationwebui/generate', jsonParser, async function (reques
if (request.body.stream) { if (request.body.stream) {
const completionsStream = await fetch(url, args); const completionsStream = await fetch(url, args);
// Pipe remote SSE stream to Express response // Pipe remote SSE stream to Express response
completionsStream.body.pipe(response_generate); forwardFetchResponse(completionsStream, response_generate);
request.socket.on('close', function () {
if (completionsStream.body instanceof Readable) completionsStream.body.destroy(); // Close the remote stream
response_generate.end(); // End the Express response
});
completionsStream.body.on('end', function () {
console.log('Streaming request finished');
response_generate.end();
});
} }
else { else {
const completionsReply = await fetch(url, args); const completionsReply = await fetch(url, args);
@ -1427,17 +1418,7 @@ async function sendClaudeRequest(request, response) {
if (request.body.stream) { if (request.body.stream) {
// Pipe remote SSE stream to Express response // Pipe remote SSE stream to Express response
generateResponse.body.pipe(response); forwardFetchResponse(generateResponse, response);
request.socket.on('close', function () {
if (generateResponse.body instanceof Readable) generateResponse.body.destroy(); // Close the remote stream
response.end(); // End the Express response
});
generateResponse.body.on('end', function () {
console.log('Streaming request finished');
response.end();
});
} else { } else {
if (!generateResponse.ok) { if (!generateResponse.ok) {
console.log(`Claude API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`); console.log(`Claude API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
@ -1640,11 +1621,7 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_op
if (fetchResponse.ok) { if (fetchResponse.ok) {
if (request.body.stream) { if (request.body.stream) {
console.log('Streaming request in progress'); console.log('Streaming request in progress');
fetchResponse.body.pipe(response_generate_openai); forwardFetchResponse(fetchResponse, response_generate_openai);
fetchResponse.body.on('end', () => {
console.log('Streaming request finished');
response_generate_openai.end();
});
} else { } else {
let json = await fetchResponse.json(); let json = await fetchResponse.json();
response_generate_openai.send(json); response_generate_openai.send(json);

View File

@ -1,9 +1,8 @@
const fetch = require('node-fetch').default; const fetch = require('node-fetch').default;
const express = require('express'); const express = require('express');
const util = require('util'); const util = require('util');
const { Readable } = require('stream');
const { readSecret, SECRET_KEYS } = require('./secrets'); const { readSecret, SECRET_KEYS } = require('./secrets');
const { readAllChunks, extractFileFromZipBuffer } = require('../util'); const { readAllChunks, extractFileFromZipBuffer, forwardFetchResponse } = require('../util');
const { jsonParser } = require('../express-common'); const { jsonParser } = require('../express-common');
const API_NOVELAI = 'https://api.novelai.net'; const API_NOVELAI = 'https://api.novelai.net';
@ -188,17 +187,7 @@ router.post('/generate', jsonParser, async function (req, res) {
if (req.body.streaming) { if (req.body.streaming) {
// Pipe remote SSE stream to Express response // Pipe remote SSE stream to Express response
response.body.pipe(res); forwardFetchResponse(response, res);
req.socket.on('close', function () {
if (response.body instanceof Readable) response.body.destroy(); // Close the remote stream
res.end(); // End the Express response
});
response.body.on('end', function () {
console.log('Streaming request finished');
res.end();
});
} else { } else {
if (!response.ok) { if (!response.ok) {
const text = await response.text(); const text = await response.text();

View File

@ -6,6 +6,7 @@ const yauzl = require('yauzl');
const mime = require('mime-types'); const mime = require('mime-types');
const yaml = require('yaml'); const yaml = require('yaml');
const { default: simpleGit } = require('simple-git'); const { default: simpleGit } = require('simple-git');
const { Readable } = require('stream');
const { DIRECTORIES } = require('./constants'); const { DIRECTORIES } = require('./constants');
@ -346,6 +347,27 @@ function getImages(path) {
.sort(Intl.Collator().compare); .sort(Intl.Collator().compare);
} }
/**
* Pipe a fetch() response to an Express.js Response, including status code.
* @param {Response} from The Fetch API response to pipe from.
* @param {Express.Response} to The Express response to pipe to.
*/
function forwardFetchResponse(from, to) {
to.statusCode = from.status;
to.statusMessage = from.statusText;
from.body.pipe(to);
to.socket.on('close', function () {
if (from.body instanceof Readable) from.body.destroy(); // Close the remote stream
to.end(); // End the Express response
});
from.body.on('end', function () {
console.log('Streaming request finished');
to.end();
});
}
module.exports = { module.exports = {
getConfig, getConfig,
getConfigValue, getConfigValue,
@ -365,4 +387,5 @@ module.exports = {
generateTimestamp, generateTimestamp,
removeOldBackups, removeOldBackups,
getImages, getImages,
forwardFetchResponse,
}; };