Streaming for Claude.

This commit is contained in:
Cohee
2023-06-06 20:18:28 +03:00
parent e205323482
commit a1b130fc9a
3 changed files with 75 additions and 38 deletions

View File

@@ -1562,7 +1562,7 @@ function appendToStoryString(value, prefix) {
} }
function isStreamingEnabled() { function isStreamingEnabled() {
return ((main_api == 'openai' && oai_settings.stream_openai && oai_settings.chat_completion_source !== chat_completion_sources.CLAUDE) return ((main_api == 'openai' && oai_settings.stream_openai)
|| (main_api == 'poe' && poe_settings.streaming) || (main_api == 'poe' && poe_settings.streaming)
|| (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming)) || (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming))
&& !isMultigenEnabled(); // Multigen has a quasi-streaming mode which breaks the real streaming && !isMultigenEnabled(); // Multigen has a quasi-streaming mode which breaks the real streaming

View File

@@ -638,7 +638,7 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
let logit_bias = {}; let logit_bias = {};
const isClaude = oai_settings.chat_completion_source == chat_completion_sources.CLAUDE; const isClaude = oai_settings.chat_completion_source == chat_completion_sources.CLAUDE;
const stream = type !== 'quiet' && oai_settings.stream_openai && !isClaude; const stream = type !== 'quiet' && oai_settings.stream_openai;
// If we're using the window.ai extension, use that instead // If we're using the window.ai extension, use that instead
// Doesn't support logit bias yet // Doesn't support logit bias yet
@@ -687,6 +687,11 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
const { done, value } = await reader.read(); const { done, value } = await reader.read();
let response = decoder.decode(value); let response = decoder.decode(value);
// Claude's streaming SSE messages are separated by \r
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
response = response.replace(/\r/g, "");
}
tryParseStreamingError(response); tryParseStreamingError(response);
let eventList = []; let eventList = [];
@@ -710,7 +715,7 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
} }
let data = JSON.parse(event.substring(6)); let data = JSON.parse(event.substring(6));
// the first and last messages are undefined, protect against that // the first and last messages are undefined, protect against that
getMessage += data.choices[0]["delta"]["content"] || ""; getMessage = getStreamingReply(getMessage, data);
yield getMessage; yield getMessage;
} }
@@ -734,6 +739,15 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
} }
} }
/**
 * Folds one parsed streaming SSE event into the accumulated reply text.
 * @param {string} getMessage - Reply text accumulated so far.
 * @param {object} data - Parsed JSON payload of a single SSE event.
 * @returns {string} The updated accumulated reply text.
 */
function getStreamingReply(getMessage, data) {
    if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
        // Replace instead of append — assumes each Claude event carries the
        // full completion-so-far rather than a delta; confirm against the API.
        getMessage = data.completion ?? "";
    } else {
        // The first and last SSE events may lack choices/delta entirely;
        // guard with optional chaining instead of crashing on undefined.
        getMessage += data.choices?.[0]?.delta?.content ?? "";
    }
    return getMessage;
}
function handleWindowError(err) { function handleWindowError(err) {
const text = parseWindowError(err); const text = parseWindowError(err);
toastr.error(text, 'Window.ai returned an error'); toastr.error(text, 'Window.ai returned an error');

View File

@@ -2725,6 +2725,7 @@ async function sendClaudeRequest(request, response) {
return response.status(401).send({ error: true }); return response.status(401).send({ error: true });
} }
try {
const controller = new AbortController(); const controller = new AbortController();
request.socket.removeAllListeners('close'); request.socket.removeAllListeners('close');
request.socket.on('close', function () { request.socket.on('close', function () {
@@ -2743,6 +2744,7 @@ async function sendClaudeRequest(request, response) {
max_tokens_to_sample: request.body.max_tokens, max_tokens_to_sample: request.body.max_tokens,
stop_sequences: ["\n\nHuman:", "\n\nSystem:", "\n\nAssistant:"], stop_sequences: ["\n\nHuman:", "\n\nSystem:", "\n\nAssistant:"],
temperature: request.body.temperature, temperature: request.body.temperature,
stream: request.body.stream,
}), }),
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
@@ -2750,6 +2752,20 @@ async function sendClaudeRequest(request, response) {
} }
}); });
if (request.body.stream) {
// Pipe remote SSE stream to Express response
generateResponse.body.pipe(response);
request.socket.on('close', function () {
generateResponse.body.destroy(); // Close the remote stream
response.end(); // End the Express response
});
generateResponse.body.on('end', function () {
console.log("Streaming request finished");
response.end();
});
} else {
if (!generateResponse.ok) { if (!generateResponse.ok) {
console.log(`Claude API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`); console.log(`Claude API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
return response.status(generateResponse.status).send({ error: true }); return response.status(generateResponse.status).send({ error: true });
@@ -2763,6 +2779,13 @@ async function sendClaudeRequest(request, response) {
const reply = { choices: [{ "message": { "content": responseText, } }] }; const reply = { choices: [{ "message": { "content": responseText, } }] };
return response.send(reply); return response.send(reply);
} }
} catch (error) {
console.log('Error communicating with Claude: ', error);
if (!response.headersSent) {
return response.status(500).send({ error: true });
}
}
}
app.post("/generate_openai", jsonParser, function (request, response_generate_openai) { app.post("/generate_openai", jsonParser, function (request, response_generate_openai) {
if (!request.body) return response_generate_openai.status(400).send({ error: true }); if (!request.body) return response_generate_openai.status(400).send({ error: true });