Add oobabooga streaming

This commit is contained in:
SillyLossy
2023-04-12 19:17:02 +03:00
parent 83982cf1fc
commit d495503ac1
4 changed files with 220 additions and 47 deletions

154
server.js
View File

@@ -39,6 +39,7 @@ const listen = config.listen;
const axios = require('axios');
const tiktoken = require('@dqbd/tiktoken');
const WebSocket = require('ws');
var Client = require('node-rest-client').Client;
var client = new Client();
@@ -308,34 +309,141 @@ app.post("/generate", jsonParser, async function (request, response_generate = r
}
});
function randomHash() {
const letters = 'abcdefghijklmnopqrstuvwxyz0123456789';
let result = '';
for (let i = 0; i < 9; i++) {
result += letters.charAt(Math.floor(Math.random() * letters.length));
}
return result;
}
function textGenProcessStartedHandler(websocket, content, session, prompt, SEND_PROMPT_GRADIO_FN) {
switch (content.msg) {
case "send_hash":
const send_hash = JSON.stringify({ "session_hash": session, "fn_index": SEND_PROMPT_GRADIO_FN });
websocket.send(send_hash);
break;
case "estimation":
break;
case "send_data":
const send_data = JSON.stringify({ "session_hash": session, "fn_index": SEND_PROMPT_GRADIO_FN, "data": prompt.data });
console.log(send_data);
websocket.send(send_data);
break;
case "process_starts":
break;
case "process_generating":
return content.output.data[0];
case "process_completed":
return null;
}
return '';
}
//************** Text generation web UI
app.post("/generate_textgenerationwebui", jsonParser, function (request, response_generate = response) {
app.post("/generate_textgenerationwebui", jsonParser, async function (request, response_generate = response) {
if (!request.body) return response_generate.sendStatus(400);
console.log(request.body);
var args = {
data: request.body,
headers: { "Content-Type": "application/json" }
};
client.post(api_server + "/run/textgen", args, function (data, response) {
console.log("####", data);
if (response.statusCode == 200) {
console.log(data);
response_generate.send(data);
if (!!request.header('X-Response-Streaming')) {
const SEND_PARAMS_GRADIO_FN = 29;
response_generate.writeHead(200, {
'Transfer-Encoding': 'chunked',
'Cache-Control': 'no-transform',
});
async function* readWebsocket() {
const session = randomHash();
const url = new URL(api_server);
const websocket = new WebSocket(`ws://${url.host}/queue/join`, { perMessageDeflate: false });
let text = '';
websocket.on('open', async function() {
console.log('websocket open');
});
websocket.on('error', (err) => {
console.error(err);
websocket.close();
});
websocket.on('close', (code, buffer) => {
const reason = new TextDecoder().decode(buffer)
console.log(reason);
});
websocket.on('message', async (message) => {
const content = json5.parse(message);
console.log(content);
text = textGenProcessStartedHandler(websocket, content, session, request.body, SEND_PARAMS_GRADIO_FN);
});
while (true) {
if (websocket.readyState == 0 || websocket.readyState == 1 || websocket.readyState == 2) {
await delay(50);
yield text;
if (!text && typeof text !== 'string') {
websocket.close();
}
}
else {
break;
}
}
}
if (response.statusCode == 422) {
console.log('Validation error');
let result = json5.parse(request.body.data)[0];
try {
for await (const text of readWebsocket()) {
if (text == null) {
break;
}
let newText = text.substring(result.length);
if (!newText) {
continue;
}
result = text;
response_generate.write(newText);
}
}
finally {
response_generate.end();
}
}
else {
var args = {
data: request.body,
headers: { "Content-Type": "application/json" }
};
client.post(api_server + "/run/textgen", args, function (data, response) {
console.log("####", data);
if (response.statusCode == 200) {
console.log(data);
response_generate.send(data);
}
if (response.statusCode == 422) {
console.log('Validation error');
response_generate.send({ error: true });
}
if (response.statusCode == 501 || response.statusCode == 503 || response.statusCode == 507) {
console.log(data);
response_generate.send({ error: true });
}
}).on('error', function (err) {
console.log(err);
//console.log('something went wrong on the request', err.request.options);
response_generate.send({ error: true });
}
if (response.statusCode == 501 || response.statusCode == 503 || response.statusCode == 507) {
console.log(data);
response_generate.send({ error: true });
}
}).on('error', function (err) {
console.log(err);
//console.log('something went wrong on the request', err.request.options);
response_generate.send({ error: true });
});
});
}
});
@@ -447,7 +555,7 @@ app.post("/getstatus", jsonParser, function (request, response_getstatus = respo
if (!response)
throw "no_connection";
let model = json5.parse(response).components.filter((x) => x.props.label == "Model" && x.type == "dropdown")[0].props.value;
data = { result: model };
data = { result: model, gradio_config: response };
if (!data)
throw "no_connection";
} catch {