From d265179f4692a53ec29adee4a49731b936403f1a Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Sun, 8 Oct 2023 23:42:28 +0300 Subject: [PATCH] Don't crash ST server on invalid streaming URL --- public/scripts/textgen-settings.js | 16 ++++++++++++++++ server.js | 14 ++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/public/scripts/textgen-settings.js b/public/scripts/textgen-settings.js index 89a6bca65..49457d8af 100644 --- a/public/scripts/textgen-settings.js +++ b/public/scripts/textgen-settings.js @@ -227,6 +227,10 @@ export function isAphrodite() { return textgenerationwebui_settings.type === textgen_types.APHRODITE; } +export function isOoba() { + return textgenerationwebui_settings.type === textgen_types.OOBA; +} + export function getTextGenUrlSourceId() { switch (textgenerationwebui_settings.type) { case textgen_types.MANCER: @@ -327,6 +331,18 @@ async function generateTextGenWithStreaming(generate_data, signal) { streamingUrl = api_server_textgenerationwebui; } + if (isMancer() || isOoba()) { + try { + const parsedUrl = new URL(streamingUrl); + if (parsedUrl.protocol !== 'ws:' && parsedUrl.protocol !== 'wss:') { + throw new Error('Invalid protocol'); + } + } catch { + toastr.error('Invalid URL for streaming. Make sure it starts with ws:// or wss://'); + return async function* () { throw new Error('Invalid URL for streaming.'); } + } + } + const response = await fetch('/generate_textgenerationwebui', { headers: { ...getRequestHeaders(), diff --git a/server.js b/server.js index b769ee605..12fc3fc3f 100644 --- a/server.js +++ b/server.js @@ -552,8 +552,18 @@ app.post("/generate_textgenerationwebui", jsonParser, async function (request, r }); async function* readWebsocket() { - const streamingUrl = new URL(streamingUrlString); - const websocket = new WebSocket(streamingUrl); + /** @type {WebSocket} */ + let websocket; + /** @type {URL} */ + let streamingUrl; + + try { + const streamingUrl = new URL(streamingUrlString); + websocket = new WebSocket(streamingUrl); + } catch (error) { + console.log("[SillyTavern] Socket error", error); + return; + } websocket.on('open', async function () { console.log('WebSocket opened');