diff --git a/public/scripts/kai-settings.js b/public/scripts/kai-settings.js index 4659b1bd0..4872f1536 100644 --- a/public/scripts/kai-settings.js +++ b/public/scripts/kai-settings.js @@ -10,6 +10,7 @@ import { import { power_user, } from './power-user.js'; +import EventSourceStream from './sse-stream.js'; import { getSortableDelay } from './utils.js'; export const kai_settings = { @@ -160,37 +161,21 @@ export async function generateKoboldWithStreaming(generate_data, signal) { method: 'POST', signal: signal, }); + const eventStream = new EventSourceStream(); + response.body.pipeThrough(eventStream); + const reader = eventStream.readable.getReader(); return async function* streamData() { - const decoder = new TextDecoder(); - const reader = response.body.getReader(); - let getMessage = ''; - let messageBuffer = ''; + let text = ''; while (true) { const { done, value } = await reader.read(); - let response = decoder.decode(value); - let eventList = []; + if (done) return; - // ReadableStream's buffer is not guaranteed to contain full SSE messages as they arrive in chunks - // We need to buffer chunks until we have one or more full messages (separated by double newlines) - messageBuffer += response; - eventList = messageBuffer.split('\n\n'); - // Last element will be an empty string or a leftover partial message - messageBuffer = eventList.pop(); - - for (let event of eventList) { - for (let subEvent of event.split('\n')) { - if (subEvent.startsWith('data')) { - let data = JSON.parse(subEvent.substring(5)); - getMessage += (data?.token || ''); - yield { text: getMessage, swipes: [] }; - } - } - } - - if (done) { - return; + const data = JSON.parse(value.data); + if (data?.token) { + text += data.token; } + yield { text, swipes: [] }; } }; } diff --git a/public/scripts/nai-settings.js b/public/scripts/nai-settings.js index 0456d6216..024ec5850 100644 --- a/public/scripts/nai-settings.js +++ b/public/scripts/nai-settings.js @@ -10,6 +10,7 @@ import { import { 
getCfgPrompt } from './cfg-scale.js'; import { MAX_CONTEXT_DEFAULT, MAX_RESPONSE_DEFAULT } from './power-user.js'; import { getTextTokens, tokenizers } from './tokenizers.js'; +import EventSourceStream from './sse-stream.js'; import { getSortableDelay, getStringHash, @@ -663,24 +664,6 @@ export function adjustNovelInstructionPrompt(prompt) { return stripedPrompt; } -function tryParseStreamingError(decoded) { - try { - const data = JSON.parse(decoded); - - if (!data) { - return; - } - - if (data.message && data.statusCode >= 400) { - toastr.error(data.message, 'Error'); - throw new Error(data); - } - } - catch { - // No JSON. Do nothing. - } -} - export async function generateNovelWithStreaming(generate_data, signal) { generate_data.streaming = nai_settings.streaming_novel; @@ -690,39 +673,27 @@ export async function generateNovelWithStreaming(generate_data, signal) { method: 'POST', signal: signal, }); + const eventStream = new EventSourceStream(); + response.body.pipeThrough(eventStream); + const reader = eventStream.readable.getReader(); return async function* streamData() { - const decoder = new TextDecoder(); - const reader = response.body.getReader(); - let getMessage = ''; - let messageBuffer = ''; + let text = ''; while (true) { const { done, value } = await reader.read(); - let decoded = decoder.decode(value); - let eventList = []; + if (done) return; - tryParseStreamingError(decoded); - - // ReadableStream's buffer is not guaranteed to contain full SSE messages as they arrive in chunks - // We need to buffer chunks until we have one or more full messages (separated by double newlines) - messageBuffer += decoded; - eventList = messageBuffer.split('\n\n'); - // Last element will be an empty string or a leftover partial message - messageBuffer = eventList.pop(); - - for (let event of eventList) { - for (let subEvent of event.split('\n')) { - if (subEvent.startsWith('data')) { - let data = JSON.parse(subEvent.substring(5)); - getMessage += (data?.token || ''); 
- yield { text: getMessage, swipes: [] }; - } - } + const data = JSON.parse(value.data); + if (data.message && data.statusCode >= 400) { + toastr.error(data.message, 'Error'); + throw new Error(data); } - if (done) { - return; + if (data.token) { + text += data.token; } + + yield { text, swipes: [] }; } }; } diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 30f4c6e29..1758a46fb 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -44,6 +44,7 @@ import { import { getCustomStoppingStrings, persona_description_positions, power_user } from './power-user.js'; import { SECRET_KEYS, secret_state, writeSecret } from './secrets.js'; +import EventSourceStream from './sse-stream.js'; import { delay, download, @@ -1565,57 +1566,22 @@ async function sendOpenAIRequest(type, messages, signal) { }); if (stream) { + const eventStream = new EventSourceStream(); + response.body.pipeThrough(eventStream); + const reader = eventStream.readable.getReader(); return async function* streamData() { - const decoder = new TextDecoder(); - const reader = response.body.getReader(); - let getMessage = ''; - let messageBuffer = ''; + let text = ''; while (true) { const { done, value } = await reader.read(); - let decoded = decoder.decode(value); + if (done) return; + if (value.data === '[DONE]') return; - // Claude's streaming SSE messages are separated by \r - if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) { - decoded = decoded.replace(/\r/g, ''); - } + tryParseStreamingError(response, value.data); - tryParseStreamingError(response, decoded); + // the first and last messages are undefined, protect against that + text += getStreamingReply(JSON.parse(value.data)); - let eventList = []; - - // ReadableStream's buffer is not guaranteed to contain full SSE messages as they arrive in chunks - // We need to buffer chunks until we have one or more full messages (separated by double newlines) - if (!oai_settings.legacy_streaming) { - 
messageBuffer += decoded; - eventList = messageBuffer.split('\n\n'); - // Last element will be an empty string or a leftover partial message - messageBuffer = eventList.pop(); - } else { - eventList = decoded.split('\n'); - } - - for (let event of eventList) { - if (event.startsWith('event: completion')) { - event = event.split('\n')[1]; - } - - if (typeof event !== 'string' || !event.length) - continue; - - if (!event.startsWith('data')) - continue; - if (event == 'data: [DONE]') { - return; - } - let data = JSON.parse(event.substring(6)); - // the first and last messages are undefined, protect against that - getMessage = getStreamingReply(getMessage, data); - yield { text: getMessage, swipes: [] }; - } - - if (done) { - return; - } + yield { text, swipes: [] }; } }; } @@ -1633,13 +1599,12 @@ async function sendOpenAIRequest(type, messages, signal) { } } -function getStreamingReply(getMessage, data) { +function getStreamingReply(data) { if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) { - getMessage += data?.completion || ''; + return data?.completion || ''; } else { - getMessage += data.choices[0]?.delta?.content || data.choices[0]?.message?.content || data.choices[0]?.text || ''; + return data.choices[0]?.delta?.content || data.choices[0]?.message?.content || data.choices[0]?.text || ''; } - return getMessage; } function handleWindowError(err) { diff --git a/public/scripts/sse-stream.js b/public/scripts/sse-stream.js new file mode 100644 index 000000000..e50bf55f3 --- /dev/null +++ b/public/scripts/sse-stream.js @@ -0,0 +1,105 @@ +/** + * A stream which handles Server-Sent Events from a binary ReadableStream like you get from the fetch API. 
/**
 * A stream which handles Server-Sent Events from a binary ReadableStream like you get from the fetch API.
 *
 * Bytes written to `writable` are decoded as UTF-8 and parsed line by line. Every complete event is emitted on
 * `readable` as a `MessageEvent` whose `type` is the event name (`'message'` when none was given), whose `data`
 * is the newline-joined `data` fields and whose `lastEventId` is the most recently received `id` field.
 *
 * Typical usage:
 *
 *     const eventStream = new EventSourceStream();
 *     response.body.pipeThrough(eventStream);
 *     const reader = eventStream.readable.getReader();
 */
class EventSourceStream {
    constructor() {
        // The event stream format requires one leading BOM to be stripped, which is what the decoder does by
        // default. (`ignoreBOM: true` would keep the BOM in the text and corrupt the first field name.)
        const decoder = new TextDecoderStream('utf-8');

        // Text received so far that has not been parsed into complete lines yet.
        let streamBuffer = '';

        // State of the event currently being assembled.
        let dataBuffer = '';
        let eventType = '';
        let lastEventId = '';

        // https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
        // Parses a single line from the event stream. The first (optional) group matches either a field
        // (field name, optional (colon, value)) or a comment (colon, text), and is followed by a group which
        // matches a newline. The only *capturing* groups are the field name, field value, comment and newline,
        // so the kind of line can be told from which captures are filled in: a field name means a field, a
        // comment capture means a comment, and neither means a blank line.
        // The regex is sticky: `lastIndex` is the position of the next unparsed line inside `streamBuffer`.
        const parserRegex = /(?:(?:([^\r\n:]+)(?:: ?([^\r\n]*)?)?)|(:[^\r\n]*))?(\r\n|\r|\n)/y;

        /**
         * Drops everything before `offset` from the stream buffer and rewinds the parser to its start.
         * @param {number} offset Index of the first character that still has to be parsed
         */
        function discardParsed(offset) {
            if (offset > 0) {
                streamBuffer = streamBuffer.slice(offset);
            }
            parserRegex.lastIndex = 0;
        }

        /**
         * Parses every complete line currently in the stream buffer and enqueues the finished events.
         * @param {TransformStreamDefaultController} controller Controller of the output stream
         * @param {boolean} isLastChunk True when no more data will arrive
         */
        function processChunk(controller, isLastChunk) {
            while (parserRegex.lastIndex < streamBuffer.length) {
                const lineStart = parserRegex.lastIndex;
                const matchResult = parserRegex.exec(streamBuffer);

                // The rest of the buffer is an incomplete line. Keep it and wait for more data.
                if (!matchResult) {
                    discardParsed(lineStart);
                    return;
                }

                const [, field, rawValue, comment, newline] = matchResult;

                // Corner case: if the last character in the buffer is '\r', we need to wait for more data.
                // Chunks can be split up any which way, so the next chunk may start with '\n', making this
                // trailing '\r' a part of a '\r\n' sequence. At the end of the stream there is nothing to wait for.
                if (newline === '\r' && parserRegex.lastIndex === streamBuffer.length && !isLastChunk) {
                    discardParsed(lineStart);
                    return;
                }

                // https://html.spec.whatwg.org/multipage/server-sent-events.html#processField
                if (typeof field === 'string') {
                    // An empty or missing value is not captured by the regex; the spec treats it as ''.
                    const value = rawValue ?? '';

                    switch (field) {
                        case 'event':
                            eventType = value;
                            break;
                        case 'data':
                            dataBuffer += `${value}\n`;
                            break;
                        case 'id':
                            // An empty id resets the last event ID; ids containing NULL are ignored.
                            if (!value.includes('\0')) lastEventId = value;
                            break;
                        default:
                            // `retry` is not handled, and unknown fields are explicitly ignored.
                            break;
                    }
                    continue;
                }

                if (typeof comment === 'string') {
                    continue;
                }

                // https://html.spec.whatwg.org/multipage/server-sent-events.html#dispatchMessage
                // Blank line: dispatch the event, unless no data was received for it.
                if (dataBuffer !== '') {
                    // Trim the *last* trailing newline only.
                    const data = dataBuffer.endsWith('\n') ? dataBuffer.slice(0, -1) : dataBuffer;
                    controller.enqueue(new MessageEvent(eventType || 'message', { data, lastEventId }));
                }

                // The buffers are reset on every blank line, even when nothing was dispatched.
                dataBuffer = '';
                eventType = '';
            }

            // Everything was consumed. Release the parsed text so the buffer does not grow with the stream.
            discardParsed(parserRegex.lastIndex);
        }

        const sseStream = new TransformStream({
            transform(chunk, controller) {
                streamBuffer += chunk;
                processChunk(controller, false);
            },

            flush(controller) {
                // If it's the last chunk, trailing carriage returns are allowed
                processChunk(controller, true);
            },
        });

        decoder.readable.pipeThrough(sseStream);

        /** @type {ReadableStream<MessageEvent>} Parsed events */
        this.readable = sseStream.readable;
        /** @type {WritableStream<Uint8Array>} Raw bytes of the event stream */
        this.writable = decoder.writable;
    }
}

export default EventSourceStream;
+ tryParseStreamingError(response, value.data); - let eventList = []; + let data = JSON.parse(value.data); - messageBuffer += response; - eventList = messageBuffer.split('\n\n'); - // Last element will be an empty string or a leftover partial message - messageBuffer = eventList.pop(); - - for (let event of eventList) { - if (event.startsWith('event: completion')) { - event = event.split('\n')[1]; - } - - if (typeof event !== 'string' || !event.length) - continue; - - if (!event.startsWith('data')) - continue; - if (event == 'data: [DONE]') { - return; - } - let data = JSON.parse(event.substring(6)); - - if (data?.choices[0]?.index > 0) { - const swipeIndex = data.choices[0].index - 1; - swipes[swipeIndex] = (swipes[swipeIndex] || '') + data.choices[0].text; - } else { - getMessage += data?.choices[0]?.text || ''; - } - - yield { text: getMessage, swipes: swipes }; + if (data?.choices[0]?.index > 0) { + const swipeIndex = data.choices[0].index - 1; + swipes[swipeIndex] = (swipes[swipeIndex] || '') + data.choices[0].text; + } else { + text += data?.choices[0]?.text || ''; } - if (done) { - return; - } + yield { text, swipes }; } }; }