Merge pull request #1482 from valadaptive/sse-stream

Refactor server-sent events parsing
This commit is contained in:
Cohee 2023-12-10 18:32:19 +02:00 committed by GitHub
commit dbd52a7994
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 266 additions and 249 deletions

View File

@ -596,7 +596,6 @@
"openrouter_model": "OR_Website",
"jailbreak_system": true,
"reverse_proxy": "",
"legacy_streaming": false,
"chat_completion_source": "openai",
"max_context_unlocked": false,
"api_url_scale": "",

View File

@ -759,19 +759,6 @@
<input type="number" id="seed_openai" name="seed_openai" class="text_pole" min="-1" max="2147483647" value="-1">
</div>
</div>
<div data-newbie-hidden class="range-block" data-source="openai,claude">
<div class="range-block-title justifyLeft">
<label for="legacy_streaming" class="checkbox_label">
<input id="legacy_streaming" type="checkbox" />
<span data-i18n="Legacy Streaming Processing">
Legacy Streaming Processing
</span>
</label>
</div>
<div class="toggle-description justifyLeft" data-i18n="Enable this if the streaming doesn't work with your proxy">
Enable this if the streaming doesn't work with your proxy.
</div>
</div>
</div>
</div>
<div id="advanced-ai-config-block" class="width100p">

View File

@ -3767,10 +3767,12 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
}
console.debug(`pushed prompt bits to itemizedPrompts array. Length is now: ${itemizedPrompts.length}`);
/** @type {Promise<any>} */
let streamingGeneratorPromise = Promise.resolve();
if (main_api == 'openai') {
if (isStreamingEnabled() && type !== 'quiet') {
streamingProcessor.generator = await sendOpenAIRequest(type, generate_data.prompt, streamingProcessor.abortController.signal);
streamingGeneratorPromise = sendOpenAIRequest(type, generate_data.prompt, streamingProcessor.abortController.signal);
}
else {
sendOpenAIRequest(type, generate_data.prompt, abortController.signal).then(onSuccess).catch(onError);
@ -3780,13 +3782,13 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
generateHorde(finalPrompt, generate_data, abortController.signal, true).then(onSuccess).catch(onError);
}
else if (main_api == 'textgenerationwebui' && isStreamingEnabled() && type !== 'quiet') {
streamingProcessor.generator = await generateTextGenWithStreaming(generate_data, streamingProcessor.abortController.signal);
streamingGeneratorPromise = generateTextGenWithStreaming(generate_data, streamingProcessor.abortController.signal);
}
else if (main_api == 'novel' && isStreamingEnabled() && type !== 'quiet') {
streamingProcessor.generator = await generateNovelWithStreaming(generate_data, streamingProcessor.abortController.signal);
streamingGeneratorPromise = generateNovelWithStreaming(generate_data, streamingProcessor.abortController.signal);
}
else if (main_api == 'kobold' && isStreamingEnabled() && type !== 'quiet') {
streamingProcessor.generator = await generateKoboldWithStreaming(generate_data, streamingProcessor.abortController.signal);
streamingGeneratorPromise = generateKoboldWithStreaming(generate_data, streamingProcessor.abortController.signal);
}
else {
try {
@ -3811,19 +3813,27 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
}
if (isStreamingEnabled() && type !== 'quiet') {
hideSwipeButtons();
let getMessage = await streamingProcessor.generate();
let messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false);
try {
const streamingGenerator = await streamingGeneratorPromise;
streamingProcessor.generator = streamingGenerator;
hideSwipeButtons();
let getMessage = await streamingProcessor.generate();
let messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false);
if (isContinue) {
getMessage = continue_mag + getMessage;
if (isContinue) {
getMessage = continue_mag + getMessage;
}
if (streamingProcessor && !streamingProcessor.isStopped && streamingProcessor.isFinished) {
await streamingProcessor.onFinishStreaming(streamingProcessor.messageId, getMessage);
streamingProcessor = null;
triggerAutoContinue(messageChunk, isImpersonate);
}
resolve();
} catch (err) {
onError(err);
}
if (streamingProcessor && !streamingProcessor.isStopped && streamingProcessor.isFinished) {
await streamingProcessor.onFinishStreaming(streamingProcessor.messageId, getMessage);
streamingProcessor = null;
triggerAutoContinue(messageChunk, isImpersonate);
}
}
async function onSuccess(data) {

View File

@ -10,6 +10,7 @@ import {
import {
power_user,
} from './power-user.js';
import EventSourceStream from './sse-stream.js';
import { getSortableDelay } from './utils.js';
export const kai_settings = {
@ -153,6 +154,24 @@ export function getKoboldGenerationData(finalPrompt, settings, maxLength, maxCon
return generate_data;
}
/**
 * Parses errors in streaming responses and displays them in toastr.
 * @param {Response} response - Fetch API response (used for the statusText fallback).
 * @param {string} decoded - Decoded response body text, possibly JSON.
 * @returns {void} Returns silently if the body is not JSON or carries no error.
 * @throws {Error} If the decoded JSON contains an `error` field.
 */
function tryParseStreamingError(response, decoded) {
    let data = null;

    try {
        data = JSON.parse(decoded);
    }
    catch {
        // No JSON. Do nothing.
        return;
    }

    if (data && data.error) {
        const message = data.error.message || response.statusText;
        toastr.error(message, 'KoboldAI API');
        // Throw outside the try block so the error actually propagates to the caller
        // (previously the bare catch above swallowed its own throw), and pass a string:
        // Error(object) would stringify to "[object Object]".
        throw new Error(message);
    }
}
export async function generateKoboldWithStreaming(generate_data, signal) {
const response = await fetch('/generate', {
headers: getRequestHeaders(),
@ -160,37 +179,25 @@ export async function generateKoboldWithStreaming(generate_data, signal) {
method: 'POST',
signal: signal,
});
if (!response.ok) {
tryParseStreamingError(response, await response.text());
throw new Error(`Got response status ${response.status}`);
}
const eventStream = new EventSourceStream();
response.body.pipeThrough(eventStream);
const reader = eventStream.readable.getReader();
return async function* streamData() {
const decoder = new TextDecoder();
const reader = response.body.getReader();
let getMessage = '';
let messageBuffer = '';
let text = '';
while (true) {
const { done, value } = await reader.read();
let response = decoder.decode(value);
let eventList = [];
if (done) return;
// ReadableStream's buffer is not guaranteed to contain full SSE messages as they arrive in chunks
// We need to buffer chunks until we have one or more full messages (separated by double newlines)
messageBuffer += response;
eventList = messageBuffer.split('\n\n');
// Last element will be an empty string or a leftover partial message
messageBuffer = eventList.pop();
for (let event of eventList) {
for (let subEvent of event.split('\n')) {
if (subEvent.startsWith('data')) {
let data = JSON.parse(subEvent.substring(5));
getMessage += (data?.token || '');
yield { text: getMessage, swipes: [] };
}
}
}
if (done) {
return;
const data = JSON.parse(value.data);
if (data?.token) {
text += data.token;
}
yield { text, swipes: [] };
}
};
}

View File

@ -10,6 +10,7 @@ import {
import { getCfgPrompt } from './cfg-scale.js';
import { MAX_CONTEXT_DEFAULT, MAX_RESPONSE_DEFAULT } from './power-user.js';
import { getTextTokens, tokenizers } from './tokenizers.js';
import EventSourceStream from './sse-stream.js';
import {
getSortableDelay,
getStringHash,
@ -663,7 +664,7 @@ export function adjustNovelInstructionPrompt(prompt) {
return stripedPrompt;
}
function tryParseStreamingError(decoded) {
function tryParseStreamingError(response, decoded) {
try {
const data = JSON.parse(decoded);
@ -671,8 +672,8 @@ function tryParseStreamingError(decoded) {
return;
}
if (data.message && data.statusCode >= 400) {
toastr.error(data.message, 'Error');
if (data.message || data.error) {
toastr.error(data.message || data.error?.message || response.statusText, 'NovelAI API');
throw new Error(data);
}
}
@ -690,39 +691,27 @@ export async function generateNovelWithStreaming(generate_data, signal) {
method: 'POST',
signal: signal,
});
if (!response.ok) {
tryParseStreamingError(response, await response.text());
throw new Error(`Got response status ${response.status}`);
}
const eventStream = new EventSourceStream();
response.body.pipeThrough(eventStream);
const reader = eventStream.readable.getReader();
return async function* streamData() {
const decoder = new TextDecoder();
const reader = response.body.getReader();
let getMessage = '';
let messageBuffer = '';
let text = '';
while (true) {
const { done, value } = await reader.read();
let decoded = decoder.decode(value);
let eventList = [];
if (done) return;
tryParseStreamingError(decoded);
const data = JSON.parse(value.data);
// ReadableStream's buffer is not guaranteed to contain full SSE messages as they arrive in chunks
// We need to buffer chunks until we have one or more full messages (separated by double newlines)
messageBuffer += decoded;
eventList = messageBuffer.split('\n\n');
// Last element will be an empty string or a leftover partial message
messageBuffer = eventList.pop();
for (let event of eventList) {
for (let subEvent of event.split('\n')) {
if (subEvent.startsWith('data')) {
let data = JSON.parse(subEvent.substring(5));
getMessage += (data?.token || '');
yield { text: getMessage, swipes: [] };
}
}
if (data.token) {
text += data.token;
}
if (done) {
return;
}
yield { text, swipes: [] };
}
};
}

View File

@ -44,6 +44,7 @@ import {
import { getCustomStoppingStrings, persona_description_positions, power_user } from './power-user.js';
import { SECRET_KEYS, secret_state, writeSecret } from './secrets.js';
import EventSourceStream from './sse-stream.js';
import {
delay,
download,
@ -215,7 +216,6 @@ const default_settings = {
openrouter_sort_models: 'alphabetically',
jailbreak_system: false,
reverse_proxy: '',
legacy_streaming: false,
chat_completion_source: chat_completion_sources.OPENAI,
max_context_unlocked: false,
api_url_scale: '',
@ -269,7 +269,6 @@ const oai_settings = {
openrouter_sort_models: 'alphabetically',
jailbreak_system: false,
reverse_proxy: '',
legacy_streaming: false,
chat_completion_source: chat_completion_sources.OPENAI,
max_context_unlocked: false,
api_url_scale: '',
@ -1124,7 +1123,7 @@ function tryParseStreamingError(response, decoded) {
checkQuotaError(data);
if (data.error) {
toastr.error(data.error.message || response.statusText, 'API returned an error');
toastr.error(data.error.message || response.statusText, 'Chat Completion API');
throw new Error(data);
}
}
@ -1564,58 +1563,28 @@ async function sendOpenAIRequest(type, messages, signal) {
signal: signal,
});
if (!response.ok) {
tryParseStreamingError(response, await response.text());
throw new Error(`Got response status ${response.status}`);
}
if (stream) {
const eventStream = new EventSourceStream();
response.body.pipeThrough(eventStream);
const reader = eventStream.readable.getReader();
return async function* streamData() {
const decoder = new TextDecoder();
const reader = response.body.getReader();
let getMessage = '';
let messageBuffer = '';
let text = '';
while (true) {
const { done, value } = await reader.read();
let decoded = decoder.decode(value);
if (done) return;
if (value.data === '[DONE]') return;
// Claude's streaming SSE messages are separated by \r
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
decoded = decoded.replace(/\r/g, '');
}
tryParseStreamingError(response, value.data);
tryParseStreamingError(response, decoded);
// the first and last messages are undefined, protect against that
text += getStreamingReply(JSON.parse(value.data));
let eventList = [];
// ReadableStream's buffer is not guaranteed to contain full SSE messages as they arrive in chunks
// We need to buffer chunks until we have one or more full messages (separated by double newlines)
if (!oai_settings.legacy_streaming) {
messageBuffer += decoded;
eventList = messageBuffer.split('\n\n');
// Last element will be an empty string or a leftover partial message
messageBuffer = eventList.pop();
} else {
eventList = decoded.split('\n');
}
for (let event of eventList) {
if (event.startsWith('event: completion')) {
event = event.split('\n')[1];
}
if (typeof event !== 'string' || !event.length)
continue;
if (!event.startsWith('data'))
continue;
if (event == 'data: [DONE]') {
return;
}
let data = JSON.parse(event.substring(6));
// the first and last messages are undefined, protect against that
getMessage = getStreamingReply(getMessage, data);
yield { text: getMessage, swipes: [] };
}
if (done) {
return;
}
yield { text, swipes: [] };
}
};
}
@ -1633,13 +1602,12 @@ async function sendOpenAIRequest(type, messages, signal) {
}
}
function getStreamingReply(getMessage, data) {
function getStreamingReply(data) {
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
getMessage += data?.completion || '';
return data?.completion || '';
} else {
getMessage += data.choices[0]?.delta?.content || data.choices[0]?.message?.content || data.choices[0]?.text || '';
return data.choices[0]?.delta?.content || data.choices[0]?.message?.content || data.choices[0]?.text || '';
}
return getMessage;
}
function handleWindowError(err) {
@ -2307,7 +2275,6 @@ function loadOpenAISettings(data, settings) {
oai_settings.openai_max_tokens = settings.openai_max_tokens ?? default_settings.openai_max_tokens;
oai_settings.bias_preset_selected = settings.bias_preset_selected ?? default_settings.bias_preset_selected;
oai_settings.bias_presets = settings.bias_presets ?? default_settings.bias_presets;
oai_settings.legacy_streaming = settings.legacy_streaming ?? default_settings.legacy_streaming;
oai_settings.max_context_unlocked = settings.max_context_unlocked ?? default_settings.max_context_unlocked;
oai_settings.send_if_empty = settings.send_if_empty ?? default_settings.send_if_empty;
oai_settings.wi_format = settings.wi_format ?? default_settings.wi_format;
@ -2370,7 +2337,6 @@ function loadOpenAISettings(data, settings) {
$('#wrap_in_quotes').prop('checked', oai_settings.wrap_in_quotes);
$('#names_in_completion').prop('checked', oai_settings.names_in_completion);
$('#jailbreak_system').prop('checked', oai_settings.jailbreak_system);
$('#legacy_streaming').prop('checked', oai_settings.legacy_streaming);
$('#openai_show_external_models').prop('checked', oai_settings.show_external_models);
$('#openai_external_category').toggle(oai_settings.show_external_models);
$('#use_ai21_tokenizer').prop('checked', oai_settings.use_ai21_tokenizer);
@ -2575,7 +2541,6 @@ async function saveOpenAIPreset(name, settings, triggerUi = true) {
bias_preset_selected: settings.bias_preset_selected,
reverse_proxy: settings.reverse_proxy,
proxy_password: settings.proxy_password,
legacy_streaming: settings.legacy_streaming,
max_context_unlocked: settings.max_context_unlocked,
wi_format: settings.wi_format,
scenario_format: settings.scenario_format,
@ -2936,7 +2901,6 @@ function onSettingsPresetChange() {
continue_nudge_prompt: ['#continue_nudge_prompt_textarea', 'continue_nudge_prompt', false],
bias_preset_selected: ['#openai_logit_bias_preset', 'bias_preset_selected', false],
reverse_proxy: ['#openai_reverse_proxy', 'reverse_proxy', false],
legacy_streaming: ['#legacy_streaming', 'legacy_streaming', true],
wi_format: ['#wi_format_textarea', 'wi_format', false],
scenario_format: ['#scenario_format_textarea', 'scenario_format', false],
personality_format: ['#personality_format_textarea', 'personality_format', false],
@ -3692,11 +3656,6 @@ $(document).ready(async function () {
saveSettingsDebounced();
});
$('#legacy_streaming').on('input', function () {
oai_settings.legacy_streaming = !!$(this).prop('checked');
saveSettingsDebounced();
});
$('#openai_bypass_status_check').on('input', function () {
oai_settings.bypass_status_check = !!$(this).prop('checked');
getStatusOpen();

View File

@ -0,0 +1,77 @@
/**
* A stream which handles Server-Sent Events from a binary ReadableStream like you get from the fetch API.
*/
class EventSourceStream {
constructor() {
const decoder = new TextDecoderStream('utf-8');
let streamBuffer = '';
let lastEventId = '';
function processChunk(controller) {
// Events are separated by two newlines
const events = streamBuffer.split(/\r\n\r\n|\r\r|\n\n/g);
if (events.length === 0) return;
// The leftover text to remain in the buffer is whatever doesn't have two newlines after it. If the buffer ended
// with two newlines, this will be an empty string.
streamBuffer = events.pop();
for (const eventChunk of events) {
let eventType = '';
// Split up by single newlines.
const lines = eventChunk.split(/\n|\r|\r\n/g);
let eventData = '';
for (const line of lines) {
const lineMatch = /([^:]+)(?:: ?(.*))?/.exec(line);
if (lineMatch) {
const field = lineMatch[1];
const value = lineMatch[2] || '';
switch (field) {
case 'event':
eventType = value;
break;
case 'data':
eventData += value;
eventData += '\n';
break;
case 'id':
// The ID field cannot contain null, per the spec
if (!value.includes('\0')) lastEventId = value;
break;
// We do nothing for the `delay` type, and other types are explicitly ignored
}
}
}
// https://html.spec.whatwg.org/multipage/server-sent-events.html#dispatchMessage
// Skip the event if the data buffer is the empty string.
if (eventData === '') continue;
if (eventData[eventData.length - 1] === '\n') {
eventData = eventData.slice(0, -1);
}
// Trim the *last* trailing newline only.
const event = new MessageEvent(eventType || 'message', { data: eventData, lastEventId });
controller.enqueue(event);
}
}
const sseStream = new TransformStream({
transform(chunk, controller) {
streamBuffer += chunk;
processChunk(controller);
},
});
decoder.readable.pipeThrough(sseStream);
this.readable = sseStream.readable;
this.writable = decoder.writable;
}
}
export default EventSourceStream;

View File

@ -14,6 +14,7 @@ import {
power_user,
registerDebugFunction,
} from './power-user.js';
import EventSourceStream from './sse-stream.js';
import { SENTENCEPIECE_TOKENIZERS, getTextTokens, tokenizers } from './tokenizers.js';
import { getSortableDelay, onlyUnique } from './utils.js';
@ -476,68 +477,50 @@ async function generateTextGenWithStreaming(generate_data, signal) {
signal: signal,
});
if (!response.ok) {
tryParseStreamingError(response, await response.text());
throw new Error(`Got response status ${response.status}`);
}
const eventStream = new EventSourceStream();
response.body.pipeThrough(eventStream);
const reader = eventStream.readable.getReader();
return async function* streamData() {
const decoder = new TextDecoder();
const reader = response.body.getReader();
let getMessage = '';
let messageBuffer = '';
let text = '';
const swipes = [];
while (true) {
const { done, value } = await reader.read();
// We don't want carriage returns in our messages
let response = decoder.decode(value).replace(/\r/g, '');
if (done) return;
if (value.data === '[DONE]') return;
tryParseStreamingError(response);
tryParseStreamingError(response, value.data);
let eventList = [];
let data = JSON.parse(value.data);
messageBuffer += response;
eventList = messageBuffer.split('\n\n');
// Last element will be an empty string or a leftover partial message
messageBuffer = eventList.pop();
for (let event of eventList) {
if (event.startsWith('event: completion')) {
event = event.split('\n')[1];
}
if (typeof event !== 'string' || !event.length)
continue;
if (!event.startsWith('data'))
continue;
if (event == 'data: [DONE]') {
return;
}
let data = JSON.parse(event.substring(6));
if (data?.choices[0]?.index > 0) {
const swipeIndex = data.choices[0].index - 1;
swipes[swipeIndex] = (swipes[swipeIndex] || '') + data.choices[0].text;
} else {
getMessage += data?.choices[0]?.text || '';
}
yield { text: getMessage, swipes: swipes };
if (data?.choices[0]?.index > 0) {
const swipeIndex = data.choices[0].index - 1;
swipes[swipeIndex] = (swipes[swipeIndex] || '') + data.choices[0].text;
} else {
text += data?.choices[0]?.text || '';
}
if (done) {
return;
}
yield { text, swipes };
}
};
}
/**
* Parses errors in streaming responses and displays them in toastr.
* @param {string} response - Response from the server.
* @param {Response} response - Response from the server.
* @param {string} decoded - Decoded response body.
* @returns {void} Nothing.
*/
function tryParseStreamingError(response) {
function tryParseStreamingError(response, decoded) {
let data = {};
try {
data = JSON.parse(response);
data = JSON.parse(decoded);
} catch {
// No JSON. Do nothing.
}
@ -545,7 +528,7 @@ function tryParseStreamingError(response) {
const message = data?.error?.message || data?.message;
if (message) {
toastr.error(message, 'API Error');
toastr.error(message, 'Text Completion API');
throw new Error(message);
}
}

View File

@ -7,7 +7,6 @@ const http = require('http');
const https = require('https');
const path = require('path');
const util = require('util');
const { Readable } = require('stream');
// cli/fs related library imports
const open = require('open');
@ -45,7 +44,20 @@ const basicAuthMiddleware = require('./src/middleware/basicAuthMiddleware');
const { jsonParser, urlencodedParser } = require('./src/express-common.js');
const contentManager = require('./src/endpoints/content-manager');
const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/endpoints/secrets');
const { delay, getVersion, getConfigValue, color, uuidv4, tryParse, clientRelativePath, removeFileExtension, generateTimestamp, removeOldBackups, getImages } = require('./src/util');
const {
delay,
getVersion,
getConfigValue,
color,
uuidv4,
tryParse,
clientRelativePath,
removeFileExtension,
generateTimestamp,
removeOldBackups,
getImages,
forwardFetchResponse,
} = require('./src/util');
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
const { convertClaudePrompt } = require('./src/chat-completion');
@ -244,9 +256,7 @@ if (getConfigValue('enableCorsProxy', false) || cliArguments.corsProxy) {
});
// Copy over relevant response params to the proxy response
res.statusCode = response.status;
res.statusMessage = response.statusText;
response.body.pipe(res);
forwardFetchResponse(response, res);
} catch (error) {
res.status(500).send('Error occurred while trying to proxy to: ' + url + ' ' + error);
@ -394,18 +404,9 @@ app.post('/generate', jsonParser, async function (request, response_generate) {
const response = await fetch(url, { method: 'POST', timeout: 0, ...args });
if (request.body.streaming) {
request.socket.on('close', function () {
if (response.body instanceof Readable) response.body.destroy(); // Close the remote stream
response_generate.end(); // End the Express response
});
response.body.on('end', function () {
console.log('Streaming request finished');
response_generate.end();
});
// Pipe remote SSE stream to Express response
return response.body.pipe(response_generate);
forwardFetchResponse(response, response_generate);
return;
} else {
if (!response.ok) {
const errorText = await response.text();
@ -603,17 +604,7 @@ app.post('/api/textgenerationwebui/generate', jsonParser, async function (reques
if (request.body.stream) {
const completionsStream = await fetch(url, args);
// Pipe remote SSE stream to Express response
completionsStream.body.pipe(response_generate);
request.socket.on('close', function () {
if (completionsStream.body instanceof Readable) completionsStream.body.destroy(); // Close the remote stream
response_generate.end(); // End the Express response
});
completionsStream.body.on('end', function () {
console.log('Streaming request finished');
response_generate.end();
});
forwardFetchResponse(completionsStream, response_generate);
}
else {
const completionsReply = await fetch(url, args);
@ -1367,17 +1358,7 @@ async function sendClaudeRequest(request, response) {
if (request.body.stream) {
// Pipe remote SSE stream to Express response
generateResponse.body.pipe(response);
request.socket.on('close', function () {
if (generateResponse.body instanceof Readable) generateResponse.body.destroy(); // Close the remote stream
response.end(); // End the Express response
});
generateResponse.body.on('end', function () {
console.log('Streaming request finished');
response.end();
});
forwardFetchResponse(generateResponse, response);
} else {
if (!generateResponse.ok) {
console.log(`Claude API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
@ -1579,20 +1560,17 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_op
try {
const fetchResponse = await fetch(endpointUrl, config);
if (request.body.stream) {
console.log('Streaming request in progress');
forwardFetchResponse(fetchResponse, response_generate_openai);
return;
}
if (fetchResponse.ok) {
if (request.body.stream) {
console.log('Streaming request in progress');
fetchResponse.body.pipe(response_generate_openai);
fetchResponse.body.on('end', () => {
console.log('Streaming request finished');
response_generate_openai.end();
});
} else {
let json = await fetchResponse.json();
response_generate_openai.send(json);
console.log(json);
console.log(json?.choices[0]?.message);
}
let json = await fetchResponse.json();
response_generate_openai.send(json);
console.log(json);
console.log(json?.choices[0]?.message);
} else if (fetchResponse.status === 429 && retries > 0) {
console.log(`Out of quota, retrying in ${Math.round(timeout / 1000)}s`);
setTimeout(() => {

View File

@ -1,9 +1,8 @@
const fetch = require('node-fetch').default;
const express = require('express');
const util = require('util');
const { Readable } = require('stream');
const { readSecret, SECRET_KEYS } = require('./secrets');
const { readAllChunks, extractFileFromZipBuffer } = require('../util');
const { readAllChunks, extractFileFromZipBuffer, forwardFetchResponse } = require('../util');
const { jsonParser } = require('../express-common');
const API_NOVELAI = 'https://api.novelai.net';
@ -190,17 +189,7 @@ router.post('/generate', jsonParser, async function (req, res) {
if (req.body.streaming) {
// Pipe remote SSE stream to Express response
response.body.pipe(res);
req.socket.on('close', function () {
if (response.body instanceof Readable) response.body.destroy(); // Close the remote stream
res.end(); // End the Express response
});
response.body.on('end', function () {
console.log('Streaming request finished');
res.end();
});
forwardFetchResponse(response, res);
} else {
if (!response.ok) {
const text = await response.text();

View File

@ -6,6 +6,7 @@ const yauzl = require('yauzl');
const mime = require('mime-types');
const yaml = require('yaml');
const { default: simpleGit } = require('simple-git');
const { Readable } = require('stream');
const { DIRECTORIES } = require('./constants');
@ -346,6 +347,43 @@ function getImages(path) {
.sort(Intl.Collator().compare);
}
/**
* Pipe a fetch() response to an Express.js Response, including status code.
* @param {Response} from The Fetch API response to pipe from.
* @param {Express.Response} to The Express response to pipe to.
*/
/**
 * Pipe a fetch() response to an Express.js Response, including status code.
 * @param {Response} from The Fetch API response to pipe from.
 * @param {Express.Response} to The Express response to pipe to.
 */
function forwardFetchResponse(from, to) {
    const { status, statusText } = from;

    if (!from.ok) {
        console.log(`Streaming request failed with status ${status} ${statusText}`);
    }

    // Avoid sending 401 responses as they reset the client Basic auth.
    // This can produce an interesting artifact as "400 Unauthorized", but it's not out of spec.
    // https://www.rfc-editor.org/rfc/rfc9110.html#name-overview-of-status-codes
    // "The reason phrases listed here are only recommendations -- they can be replaced by local
    // equivalents or left out altogether without affecting the protocol."
    to.statusCode = status === 401 ? 400 : status;
    to.statusMessage = statusText;

    from.body.pipe(to);

    to.socket.on('close', function () {
        // Close the remote stream when the client disconnects
        if (from.body instanceof Readable) from.body.destroy();
        // End the Express response
        to.end();
    });

    from.body.on('end', function () {
        console.log('Streaming request finished');
        to.end();
    });
}
module.exports = {
getConfig,
getConfigValue,
@ -365,4 +403,5 @@ module.exports = {
generateTimestamp,
removeOldBackups,
getImages,
forwardFetchResponse,
};