Add proper processing of streaming aborting

This commit is contained in:
SillyLossy
2023-04-21 20:29:18 +03:00
parent 9af7c63d9c
commit f25ecbd95c
6 changed files with 62 additions and 10 deletions

View File

@ -259,6 +259,7 @@ class Client {
constructor(auto_reconnect = false, use_cached_bots = false) { constructor(auto_reconnect = false, use_cached_bots = false) {
this.auto_reconnect = auto_reconnect; this.auto_reconnect = auto_reconnect;
this.use_cached_bots = use_cached_bots; this.use_cached_bots = use_cached_bots;
this.abortController = new AbortController();
} }
async init(token, proxy = null) { async init(token, proxy = null) {
@ -267,6 +268,7 @@ class Client {
timeout: 60000, timeout: 60000,
httpAgent: new http.Agent({ keepAlive: true }), httpAgent: new http.Agent({ keepAlive: true }),
httpsAgent: new https.Agent({ keepAlive: true }), httpsAgent: new https.Agent({ keepAlive: true }),
signal: this.abortController.signal,
}); });
if (proxy) { if (proxy) {
this.session.defaults.proxy = { this.session.defaults.proxy = {
@ -544,6 +546,8 @@ class Client {
let messageId; let messageId;
while (true) { while (true) {
try { try {
this.abortController.signal.throwIfAborted();
const message = this.message_queues[humanMessageId].shift(); const message = this.message_queues[humanMessageId].shift();
if (!message) { if (!message) {
await new Promise(resolve => setTimeout(() => resolve(), 1000)); await new Promise(resolve => setTimeout(() => resolve(), 1000));

View File

@ -1273,6 +1273,7 @@ class StreamingProcessor {
this.isStopped = false; this.isStopped = false;
this.isFinished = false; this.isFinished = false;
this.generator = this.nullStreamingGeneration; this.generator = this.nullStreamingGeneration;
this.abortController = new AbortController();
} }
async generate() { async generate() {
@ -1927,7 +1928,7 @@ async function Generate(type, automatic_trigger, force_name2) {
let prompt = await prepareOpenAIMessages(name2, storyString, worldInfoBefore, worldInfoAfter, afterScenarioAnchor, promptBias, type); let prompt = await prepareOpenAIMessages(name2, storyString, worldInfoBefore, worldInfoAfter, afterScenarioAnchor, promptBias, type);
if (isStreamingEnabled()) { if (isStreamingEnabled()) {
streamingProcessor.generator = await sendOpenAIRequest(prompt); streamingProcessor.generator = await sendOpenAIRequest(prompt, streamingProcessor.abortController.signal);
} }
else { else {
sendOpenAIRequest(prompt).then(onSuccess).catch(onError); sendOpenAIRequest(prompt).then(onSuccess).catch(onError);
@ -1938,14 +1939,14 @@ async function Generate(type, automatic_trigger, force_name2) {
} }
else if (main_api == 'poe') { else if (main_api == 'poe') {
if (isStreamingEnabled()) { if (isStreamingEnabled()) {
streamingProcessor.generator = await generatePoe(type, finalPromt); streamingProcessor.generator = await generatePoe(type, finalPromt, streamingProcessor.abortController.signal);
} }
else { else {
generatePoe(type, finalPromt).then(onSuccess).catch(onError); generatePoe(type, finalPromt).then(onSuccess).catch(onError);
} }
} }
else if (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming) { else if (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming) {
streamingProcessor.generator = await generateTextGenWithStreaming(generate_data); streamingProcessor.generator = await generateTextGenWithStreaming(generate_data, streamingProcessor.abortController.signal);
} }
else { else {
jQuery.ajax({ jQuery.ajax({
@ -5013,6 +5014,7 @@ $(document).ready(function () {
$(document).on("click", ".mes_stop", function () { $(document).on("click", ".mes_stop", function () {
if (streamingProcessor) { if (streamingProcessor) {
streamingProcessor.abortController.abort();
streamingProcessor.isStopped = true; streamingProcessor.isStopped = true;
streamingProcessor.onStopStreaming(); streamingProcessor.onStopStreaming();
streamingProcessor = null; streamingProcessor = null;
@ -5106,4 +5108,11 @@ $(document).ready(function () {
} }
}); });
}); });
$(document).on('beforeunload', () => {
if (streamingProcessor) {
console.log('Page reloaded. Aborting streaming...');
streamingProcessor.abortController.abort();
}
});
}) })

View File

@ -436,7 +436,12 @@ function getSystemPrompt(nsfw_toggle_prompt, enhance_definitions_prompt, wiBefor
return whole_prompt; return whole_prompt;
} }
async function sendOpenAIRequest(openai_msgs_tosend) { async function sendOpenAIRequest(openai_msgs_tosend, signal) {
// Provide default abort signal
if (!signal) {
signal = new AbortController().signal;
}
if (oai_settings.reverse_proxy) { if (oai_settings.reverse_proxy) {
validateReverseProxy(); validateReverseProxy();
} }
@ -459,7 +464,8 @@ async function sendOpenAIRequest(openai_msgs_tosend) {
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
"X-CSRF-Token": token, "X-CSRF-Token": token,
} },
signal: signal,
}); });
if (oai_settings.stream_openai) { if (oai_settings.stream_openai) {

View File

@ -86,7 +86,7 @@ function onBotChange() {
saveSettingsDebounced(); saveSettingsDebounced();
} }
async function generatePoe(type, finalPrompt) { async function generatePoe(type, finalPrompt, signal) {
if (poe_settings.auto_purge) { if (poe_settings.auto_purge) {
let count_to_delete = -1; let count_to_delete = -1;
@ -136,7 +136,7 @@ async function generatePoe(type, finalPrompt) {
finalPrompt = sentences.join(''); finalPrompt = sentences.join('');
} }
const reply = await sendMessage(finalPrompt, true); const reply = await sendMessage(finalPrompt, true, signal);
got_reply = true; got_reply = true;
return reply; return reply;
} }
@ -160,7 +160,11 @@ async function purgeConversation(count = -1) {
return response.ok; return response.ok;
} }
async function sendMessage(prompt, withStreaming) { async function sendMessage(prompt, withStreaming, signal) {
if (!signal) {
signal = new AbortController().signal;
}
const body = JSON.stringify({ const body = JSON.stringify({
bot: poe_settings.bot, bot: poe_settings.bot,
token: poe_settings.token, token: poe_settings.token,
@ -175,6 +179,7 @@ async function sendMessage(prompt, withStreaming) {
}, },
body: body, body: body,
method: 'POST', method: 'POST',
signal: signal,
}); });
if (withStreaming && poe_settings.streaming) { if (withStreaming && poe_settings.streaming) {

View File

@ -147,7 +147,7 @@ function setSettingByName(i, value, trigger) {
} }
} }
async function generateTextGenWithStreaming(generate_data) { async function generateTextGenWithStreaming(generate_data, signal) {
const response = await fetch('/generate_textgenerationwebui', { const response = await fetch('/generate_textgenerationwebui', {
headers: { headers: {
'X-CSRF-Token': token, 'X-CSRF-Token': token,
@ -157,6 +157,7 @@ async function generateTextGenWithStreaming(generate_data) {
}, },
body: JSON.stringify(generate_data), body: JSON.stringify(generate_data),
method: 'POST', method: 'POST',
signal: signal,
}); });
return async function* streamData() { return async function* streamData() {

View File

@ -367,6 +367,10 @@ app.post("/generate_textgenerationwebui", jsonParser, async function (request, r
if (!!request.header('X-Response-Streaming')) { if (!!request.header('X-Response-Streaming')) {
const fn_index = Number(request.header('X-Gradio-Streaming-Function')); const fn_index = Number(request.header('X-Gradio-Streaming-Function'));
let isStreamingStopped = false;
request.socket.on('close', function() {
isStreamingStopped = true;
});
response_generate.writeHead(200, { response_generate.writeHead(200, {
'Content-Type': 'text/plain;charset=utf-8', 'Content-Type': 'text/plain;charset=utf-8',
@ -404,6 +408,12 @@ app.post("/generate_textgenerationwebui", jsonParser, async function (request, r
}); });
while (true) { while (true) {
if (isStreamingStopped) {
console.error('Streaming stopped by user. Closing websocket...');
websocket.close();
return null;
}
if (websocket.readyState == 0 || websocket.readyState == 1 || websocket.readyState == 2) { if (websocket.readyState == 0 || websocket.readyState == 1 || websocket.readyState == 2) {
await delay(50); await delay(50);
yield text; yield text;
@ -1895,6 +1905,12 @@ app.post('/generate_poe', jsonParser, async (request, response) => {
} }
if (streaming) { if (streaming) {
let isStreamingStopped = false;
request.socket.on('close', function() {
isStreamingStopped = true;
client.abortController.abort();
});
try { try {
response.writeHead(200, { response.writeHead(200, {
'Content-Type': 'text/plain;charset=utf-8', 'Content-Type': 'text/plain;charset=utf-8',
@ -1904,6 +1920,11 @@ app.post('/generate_poe', jsonParser, async (request, response) => {
let reply = ''; let reply = '';
for await (const mes of client.send_message(bot, prompt)) { for await (const mes of client.send_message(bot, prompt)) {
if (isStreamingStopped) {
console.error('Streaming stopped by user. Closing websocket...');
break;
}
let newText = mes.text.substring(reply.length); let newText = mes.text.substring(reply.length);
reply = mes.text; reply = mes.text;
response.write(newText); response.write(newText);
@ -2135,6 +2156,11 @@ app.post("/generate_openai", jsonParser, function (request, response_generate_op
if (!request.body) return response_generate_openai.sendStatus(400); if (!request.body) return response_generate_openai.sendStatus(400);
const api_url = new URL(request.body.reverse_proxy || api_openai).toString(); const api_url = new URL(request.body.reverse_proxy || api_openai).toString();
const controller = new AbortController();
request.socket.on('close', function() {
controller.abort();
});
console.log(request.body); console.log(request.body);
const config = { const config = {
method: 'post', method: 'post',
@ -2153,7 +2179,8 @@ app.post("/generate_openai", jsonParser, function (request, response_generate_op
"frequency_penalty": request.body.frequency_penalty, "frequency_penalty": request.body.frequency_penalty,
"stop": request.body.stop, "stop": request.body.stop,
"logit_bias": request.body.logit_bias "logit_bias": request.body.logit_bias
} },
signal: controller.signal,
}; };
if (request.body.stream) if (request.body.stream)