Fix streaming processor error handler hooks

This commit is contained in:
Cohee 2023-12-08 02:01:08 +02:00
parent 055d6c4337
commit b0e7b73a32
6 changed files with 42 additions and 25 deletions

View File

@ -2730,6 +2730,10 @@ class StreamingProcessor {
this.onErrorStreaming(); this.onErrorStreaming();
} }
hook(generatorFn) {
this.generator = generatorFn;
}
*nullStreamingGeneration() { *nullStreamingGeneration() {
throw new Error('Generation function for streaming is not hooked up'); throw new Error('Generation function for streaming is not hooked up');
} }
@ -3722,10 +3726,14 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
} }
console.debug(`pushed prompt bits to itemizedPrompts array. Length is now: ${itemizedPrompts.length}`); console.debug(`pushed prompt bits to itemizedPrompts array. Length is now: ${itemizedPrompts.length}`);
/** @type {Promise<any>} */
let streamingHookPromise = Promise.resolve();
if (main_api == 'openai') { if (main_api == 'openai') {
if (isStreamingEnabled() && type !== 'quiet') { if (isStreamingEnabled() && type !== 'quiet') {
streamingProcessor.generator = await sendOpenAIRequest(type, generate_data.prompt, streamingProcessor.abortController.signal); streamingHookPromise = sendOpenAIRequest(type, generate_data.prompt, streamingProcessor.abortController.signal)
.then(fn => streamingProcessor.hook(fn))
.catch(onError);
} }
else { else {
sendOpenAIRequest(type, generate_data.prompt, abortController.signal).then(onSuccess).catch(onError); sendOpenAIRequest(type, generate_data.prompt, abortController.signal).then(onSuccess).catch(onError);
@ -3735,13 +3743,19 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
generateHorde(finalPrompt, generate_data, abortController.signal, true).then(onSuccess).catch(onError); generateHorde(finalPrompt, generate_data, abortController.signal, true).then(onSuccess).catch(onError);
} }
else if (main_api == 'textgenerationwebui' && isStreamingEnabled() && type !== 'quiet') { else if (main_api == 'textgenerationwebui' && isStreamingEnabled() && type !== 'quiet') {
streamingProcessor.generator = await generateTextGenWithStreaming(generate_data, streamingProcessor.abortController.signal); streamingHookPromise = generateTextGenWithStreaming(generate_data, streamingProcessor.abortController.signal)
.then(fn => streamingProcessor.hook(fn))
.catch(onError);
} }
else if (main_api == 'novel' && isStreamingEnabled() && type !== 'quiet') { else if (main_api == 'novel' && isStreamingEnabled() && type !== 'quiet') {
streamingProcessor.generator = await generateNovelWithStreaming(generate_data, streamingProcessor.abortController.signal); streamingHookPromise = generateNovelWithStreaming(generate_data, streamingProcessor.abortController.signal)
.then(fn => streamingProcessor.hook(fn))
.catch(onError);
} }
else if (main_api == 'kobold' && isStreamingEnabled() && type !== 'quiet') { else if (main_api == 'kobold' && isStreamingEnabled() && type !== 'quiet') {
streamingProcessor.generator = await generateKoboldWithStreaming(generate_data, streamingProcessor.abortController.signal); streamingHookPromise = generateKoboldWithStreaming(generate_data, streamingProcessor.abortController.signal)
.then(fn => streamingProcessor.hook(fn))
.catch(onError);
} }
else { else {
try { try {
@ -3767,6 +3781,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
if (isStreamingEnabled() && type !== 'quiet') { if (isStreamingEnabled() && type !== 'quiet') {
hideSwipeButtons(); hideSwipeButtons();
await streamingHookPromise;
let getMessage = await streamingProcessor.generate(); let getMessage = await streamingProcessor.generate();
let messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false); let messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false);

View File

@ -163,7 +163,7 @@ function tryParseStreamingError(response, decoded) {
} }
if (data.error) { if (data.error) {
toastr.error(data.error.message || response.statusText, 'API returned an error'); toastr.error(data.error.message || response.statusText, 'KoboldAI API');
throw new Error(data); throw new Error(data);
} }
} }
@ -180,7 +180,7 @@ export async function generateKoboldWithStreaming(generate_data, signal) {
signal: signal, signal: signal,
}); });
if (!response.ok) { if (!response.ok) {
tryParseStreamingError(response, await response.body.text()); tryParseStreamingError(response, await response.text());
throw new Error(`Got response status ${response.status}`); throw new Error(`Got response status ${response.status}`);
} }
const eventStream = new EventSourceStream(); const eventStream = new EventSourceStream();

View File

@ -672,8 +672,8 @@ function tryParseStreamingError(response, decoded) {
return; return;
} }
if (data.error) { if (data.message || data.error) {
toastr.error(data.error.message || response.statusText, 'API returned an error'); toastr.error(data.message || data.error?.message || response.statusText, 'NovelAI API');
throw new Error(data); throw new Error(data);
} }
} }
@ -692,7 +692,7 @@ export async function generateNovelWithStreaming(generate_data, signal) {
signal: signal, signal: signal,
}); });
if (!response.ok) { if (!response.ok) {
tryParseStreamingError(response, await response.body.text()); tryParseStreamingError(response, await response.text());
throw new Error(`Got response status ${response.status}`); throw new Error(`Got response status ${response.status}`);
} }
const eventStream = new EventSourceStream(); const eventStream = new EventSourceStream();

View File

@ -1123,7 +1123,7 @@ function tryParseStreamingError(response, decoded) {
checkQuotaError(data); checkQuotaError(data);
if (data.error) { if (data.error) {
toastr.error(data.error.message || response.statusText, 'API returned an error'); toastr.error(data.error.message || response.statusText, 'Chat Completion API');
throw new Error(data); throw new Error(data);
} }
} }
@ -1564,7 +1564,7 @@ async function sendOpenAIRequest(type, messages, signal) {
}); });
if (!response.ok) { if (!response.ok) {
tryParseStreamingError(response, await response.body.text()); tryParseStreamingError(response, await response.text());
throw new Error(`Got response status ${response.status}`); throw new Error(`Got response status ${response.status}`);
} }

View File

@ -478,7 +478,7 @@ async function generateTextGenWithStreaming(generate_data, signal) {
}); });
if (!response.ok) { if (!response.ok) {
tryParseStreamingError(response, await response.body.text()); tryParseStreamingError(response, await response.text());
throw new Error(`Got response status ${response.status}`); throw new Error(`Got response status ${response.status}`);
} }
@ -512,14 +512,15 @@ async function generateTextGenWithStreaming(generate_data, signal) {
/** /**
* Parses errors in streaming responses and displays them in toastr. * Parses errors in streaming responses and displays them in toastr.
* @param {string} response - Response from the server. * @param {Response} response - Response from the server.
* @param {string} decoded - Decoded response body.
* @returns {void} Nothing. * @returns {void} Nothing.
*/ */
function tryParseStreamingError(response) { function tryParseStreamingError(response, decoded) {
let data = {}; let data = {};
try { try {
data = JSON.parse(response); data = JSON.parse(decoded);
} catch { } catch {
// No JSON. Do nothing. // No JSON. Do nothing.
} }
@ -527,7 +528,7 @@ function tryParseStreamingError(response) {
const message = data?.error?.message || data?.message; const message = data?.error?.message || data?.message;
if (message) { if (message) {
toastr.error(message, 'API Error'); toastr.error(message, 'Text Completion API');
throw new Error(message); throw new Error(message);
} }
} }

View File

@ -1618,16 +1618,17 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_op
try { try {
const fetchResponse = await fetch(endpointUrl, config); const fetchResponse = await fetch(endpointUrl, config);
if (request.body.stream) {
console.log('Streaming request in progress');
forwardFetchResponse(fetchResponse, response_generate_openai);
return;
}
if (fetchResponse.ok) { if (fetchResponse.ok) {
if (request.body.stream) { let json = await fetchResponse.json();
console.log('Streaming request in progress'); response_generate_openai.send(json);
forwardFetchResponse(fetchResponse, response_generate_openai); console.log(json);
} else { console.log(json?.choices[0]?.message);
let json = await fetchResponse.json();
response_generate_openai.send(json);
console.log(json);
console.log(json?.choices[0]?.message);
}
} else if (fetchResponse.status === 429 && retries > 0) { } else if (fetchResponse.status === 429 && retries > 0) {
console.log(`Out of quota, retrying in ${Math.round(timeout / 1000)}s`); console.log(`Out of quota, retrying in ${Math.round(timeout / 1000)}s`);
setTimeout(() => { setTimeout(() => {