Added streaming for Google models

This commit is contained in:
based
2023-12-14 21:03:41 +10:00
parent 3e82a7d439
commit ca87f29771
3 changed files with 86 additions and 40 deletions

View File

@ -2557,7 +2557,7 @@ function getCharacterCardFields() {
} }
function isStreamingEnabled() { function isStreamingEnabled() {
const noStreamSources = [chat_completion_sources.SCALE, chat_completion_sources.AI21, chat_completion_sources.MAKERSUITE]; const noStreamSources = [chat_completion_sources.SCALE, chat_completion_sources.AI21];
return ((main_api == 'openai' && oai_settings.stream_openai && !noStreamSources.includes(oai_settings.chat_completion_source)) return ((main_api == 'openai' && oai_settings.stream_openai && !noStreamSources.includes(oai_settings.chat_completion_source))
|| (main_api == 'kobold' && kai_settings.streaming_kobold && kai_flags.can_use_streaming) || (main_api == 'kobold' && kai_settings.streaming_kobold && kai_flags.can_use_streaming)
|| (main_api == 'novel' && nai_settings.streaming_novel) || (main_api == 'novel' && nai_settings.streaming_novel)

View File

@ -1452,7 +1452,7 @@ async function sendOpenAIRequest(type, messages, signal) {
const isQuiet = type === 'quiet'; const isQuiet = type === 'quiet';
const isImpersonate = type === 'impersonate'; const isImpersonate = type === 'impersonate';
const isContinue = type === 'continue'; const isContinue = type === 'continue';
const stream = oai_settings.stream_openai && !isQuiet && !isScale && !isAI21 && !isGoogle; const stream = oai_settings.stream_openai && !isQuiet && !isScale && !isAI21;
if (isTextCompletion && isOpenRouter) { if (isTextCompletion && isOpenRouter) {
messages = convertChatCompletionToInstruct(messages, type); messages = convertChatCompletionToInstruct(messages, type);
@ -1571,23 +1571,26 @@ async function sendOpenAIRequest(type, messages, signal) {
tryParseStreamingError(response, await response.text()); tryParseStreamingError(response, await response.text());
throw new Error(`Got response status ${response.status}`); throw new Error(`Got response status ${response.status}`);
} }
if (stream) { if (stream) {
const eventStream = new EventSourceStream(); let reader;
response.body.pipeThrough(eventStream); let isSSEStream = oai_settings.chat_completion_source !== chat_completion_sources.MAKERSUITE;
const reader = eventStream.readable.getReader(); if (isSSEStream) {
const eventStream = new EventSourceStream();
response.body.pipeThrough(eventStream);
reader = eventStream.readable.getReader();
} else {
reader = response.body.getReader();
}
return async function* streamData() { return async function* streamData() {
let text = ''; let text = '';
let utf8Decoder = new TextDecoder();
while (true) { while (true) {
const { done, value } = await reader.read(); const { done, value } = await reader.read();
if (done) return; if (done) return;
if (value.data === '[DONE]') return; const rawData = isSSEStream ? value.data : utf8Decoder.decode(value, { stream: true });
if (isSSEStream && rawData === '[DONE]') return;
tryParseStreamingError(response, value.data); tryParseStreamingError(response, rawData);
text += getStreamingReply(JSON.parse(rawData));
// the first and last messages are undefined, protect against that
text += getStreamingReply(JSON.parse(value.data));
yield { text, swipes: [] }; yield { text, swipes: [] };
} }
}; };
@ -1609,6 +1612,8 @@ async function sendOpenAIRequest(type, messages, signal) {
function getStreamingReply(data) { function getStreamingReply(data) {
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) { if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
return data?.completion || ''; return data?.completion || '';
} else if (oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) {
return data?.candidates[0].content.parts[0].text || '';
} else { } else {
return data.choices[0]?.delta?.content || data.choices[0]?.message?.content || data.choices[0]?.text || ''; return data.choices[0]?.delta?.content || data.choices[0]?.message?.content || data.choices[0]?.text || '';
} }

View File

@ -1018,6 +1018,7 @@ async function sendMakerSuiteRequest(request, response) {
}; };
const google_model = request.body.model; const google_model = request.body.model;
const should_stream = request.body.stream;
try { try {
const controller = new AbortController(); const controller = new AbortController();
request.socket.removeAllListeners('close'); request.socket.removeAllListeners('close');
@ -1025,7 +1026,7 @@ async function sendMakerSuiteRequest(request, response) {
controller.abort(); controller.abort();
}); });
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${google_model}:generateContent?key=${api_key_makersuite}`, { const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${google_model}:${should_stream ? 'streamGenerateContent' : 'generateContent'}?key=${api_key_makersuite}`, {
body: JSON.stringify(body), body: JSON.stringify(body),
method: 'POST', method: 'POST',
headers: { headers: {
@ -1034,37 +1035,77 @@ async function sendMakerSuiteRequest(request, response) {
signal: controller.signal, signal: controller.signal,
timeout: 0, timeout: 0,
}); });
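// The upstream streaming endpoint does not deliver ready-to-forward events, so raw chunks are buffered until they parse as complete JSON and only then written to the client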
if (should_stream) {
try {
let partialData = '';
generateResponse.body.on('data', (data) => {
const chunk = data.toString();
if (chunk.startsWith(',') || chunk.endsWith(',') || chunk.startsWith('[') || chunk.endsWith(']')) {
partialData = chunk.slice(1);
} else {
partialData += chunk;
}
while (true) {
let json;
try {
json = JSON.parse(partialData);
} catch (e) {
break;
}
response.write(JSON.stringify(json));
partialData = '';
}
});
if (!generateResponse.ok) { request.socket.on('close', function () {
console.log(`MakerSuite API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`); generateResponse.body.destroy();
return response.status(generateResponse.status).send({ error: true }); response.end();
} });
const generateResponseJson = await generateResponse.json(); generateResponse.body.on('end', () => {
console.log('Streaming request finished');
response.end();
});
const candidates = generateResponseJson?.candidates; } catch (error) {
if (!candidates || candidates.length === 0) { console.log('Error forwarding streaming response:', error);
let message = 'MakerSuite API returned no candidate'; if (!response.headersSent) {
console.log(message, generateResponseJson); return response.status(500).send({ error: true });
if (generateResponseJson?.promptFeedback?.blockReason) { }
message += `\nPrompt was blocked due to : ${generateResponseJson.promptFeedback.blockReason}`;
} }
return response.send({ error: { message } }); } else {
if (!generateResponse.ok) {
console.log(`MakerSuite API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
return response.status(generateResponse.status).send({ error: true });
}
const generateResponseJson = await generateResponse.json();
const candidates = generateResponseJson?.candidates;
if (!candidates || candidates.length === 0) {
let message = 'MakerSuite API returned no candidate';
console.log(message, generateResponseJson);
if (generateResponseJson?.promptFeedback?.blockReason) {
message += `\nPrompt was blocked due to : ${generateResponseJson.promptFeedback.blockReason}`;
}
return response.send({ error: { message } });
}
const responseContent = candidates[0].content;
const responseText = responseContent.parts[0].text;
if (!responseText) {
let message = 'MakerSuite Candidate text empty';
console.log(message, generateResponseJson);
return response.send({ error: { message } });
}
console.log('MakerSuite response:', responseText);
// Wrap it back to OAI format
const reply = { choices: [{ 'message': { 'content': responseText } }] };
return response.send(reply);
} }
const responseContent = candidates[0].content;
const responseText = responseContent.parts[0].text;
if (!responseText) {
let message = 'MakerSuite Candidate text empty';
console.log(message, generateResponseJson);
return response.send({ error: { message } });
}
console.log('MakerSuite response:', responseText);
// Wrap it back to OAI format
const reply = { choices: [{ 'message': { 'content': responseText } }] };
return response.send(reply);
} catch (error) { } catch (error) {
console.log('Error communicating with MakerSuite API: ', error); console.log('Error communicating with MakerSuite API: ', error);
if (!response.headersSent) { if (!response.headersSent) {