mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-01-05 21:46:49 +01:00
Cohere: new stream parser
This commit is contained in:
parent
9286daaf3f
commit
ebab976221
126
src/cohere-stream.js
Normal file
126
src/cohere-stream.js
Normal file
@ -0,0 +1,126 @@
|
||||
const DATA_PREFIX = 'data:';
|
||||
|
||||
/**
|
||||
* Borrowed from Cohere SDK (MIT License)
|
||||
* https://github.com/cohere-ai/cohere-typescript/blob/main/src/core/streaming-fetcher/Stream.ts
|
||||
* Copyright (c) 2021 Cohere
|
||||
*/
|
||||
class CohereStream {
|
||||
/** @type {ReadableStream} */
|
||||
stream;
|
||||
/** @type {string} */
|
||||
prefix;
|
||||
/** @type {string} */
|
||||
messageTerminator;
|
||||
/** @type {string|undefined} */
|
||||
streamTerminator;
|
||||
/** @type {AbortController} */
|
||||
controller = new AbortController();
|
||||
|
||||
constructor({ stream, eventShape }) {
|
||||
this.stream = stream;
|
||||
if (eventShape.type === 'sse') {
|
||||
this.prefix = DATA_PREFIX;
|
||||
this.messageTerminator = '\n';
|
||||
this.streamTerminator = eventShape.streamTerminator;
|
||||
} else {
|
||||
this.messageTerminator = eventShape.messageTerminator;
|
||||
}
|
||||
}
|
||||
|
||||
async *iterMessages() {
|
||||
const stream = readableStreamAsyncIterable(this.stream);
|
||||
let buf = '';
|
||||
let prefixSeen = false;
|
||||
let parsedAnyMessages = false;
|
||||
for await (const chunk of stream) {
|
||||
buf += this.decodeChunk(chunk);
|
||||
|
||||
let terminatorIndex;
|
||||
// Parse the chunk into as many messages as possible
|
||||
while ((terminatorIndex = buf.indexOf(this.messageTerminator)) >= 0) {
|
||||
// Extract the line from the buffer
|
||||
let line = buf.slice(0, terminatorIndex + 1);
|
||||
buf = buf.slice(terminatorIndex + 1);
|
||||
|
||||
// Skip empty lines
|
||||
if (line.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip the chunk until the prefix is found
|
||||
if (!prefixSeen && this.prefix != null) {
|
||||
const prefixIndex = line.indexOf(this.prefix);
|
||||
if (prefixIndex === -1) {
|
||||
continue;
|
||||
}
|
||||
prefixSeen = true;
|
||||
line = line.slice(prefixIndex + this.prefix.length);
|
||||
}
|
||||
|
||||
// If the stream terminator is present, return
|
||||
if (this.streamTerminator != null && line.includes(this.streamTerminator)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, yield message from the prefix to the terminator
|
||||
const message = JSON.parse(line);
|
||||
yield message;
|
||||
prefixSeen = false;
|
||||
parsedAnyMessages = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!parsedAnyMessages && buf.length > 0) {
|
||||
try {
|
||||
yield JSON.parse(buf);
|
||||
} catch (e) {
|
||||
console.error('Error parsing message:', e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for await (const message of this.iterMessages()) {
|
||||
yield message;
|
||||
}
|
||||
}
|
||||
|
||||
decodeChunk(chunk) {
|
||||
const decoder = new TextDecoder('utf8');
|
||||
return decoder.decode(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
function readableStreamAsyncIterable(stream) {
|
||||
if (stream[Symbol.asyncIterator]) {
|
||||
return stream;
|
||||
}
|
||||
|
||||
const reader = stream.getReader();
|
||||
return {
|
||||
async next() {
|
||||
try {
|
||||
const result = await reader.read();
|
||||
if (result?.done) {
|
||||
reader.releaseLock();
|
||||
} // release lock when stream becomes closed
|
||||
return result;
|
||||
} catch (e) {
|
||||
reader.releaseLock(); // release lock when stream becomes errored
|
||||
throw e;
|
||||
}
|
||||
},
|
||||
async return() {
|
||||
const cancelPromise = reader.cancel();
|
||||
reader.releaseLock();
|
||||
await cancelPromise;
|
||||
return { done: true, value: undefined };
|
||||
},
|
||||
[Symbol.asyncIterator]() {
|
||||
return this;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = CohereStream;
|
@ -6,6 +6,7 @@ const { jsonParser } = require('../../express-common');
|
||||
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
|
||||
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
|
||||
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools, convertAI21Messages } = require('../../prompt-converters');
|
||||
const CohereStream = require('../../cohere-stream');
|
||||
|
||||
const { readSecret, SECRET_KEYS } = require('../secrets');
|
||||
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
|
||||
@ -41,42 +42,30 @@ function postProcessPrompt(messages, type, charName, userName) {
|
||||
/**
|
||||
* Ollama strikes back. Special boy #2's steaming routine.
|
||||
* Wrap this abomination into proper SSE stream, again.
|
||||
* @param {import('node-fetch').Response} jsonStream JSON stream
|
||||
* @param {Response} jsonStream JSON stream
|
||||
* @param {import('express').Request} request Express request
|
||||
* @param {import('express').Response} response Express response
|
||||
* @returns {Promise<any>} Nothing valuable
|
||||
*/
|
||||
async function parseCohereStream(jsonStream, request, response) {
|
||||
try {
|
||||
jsonStream.body.on('data', (data) => {
|
||||
try {
|
||||
const json = JSON.parse(data.toString());
|
||||
if (json.message) {
|
||||
const message = json.message || 'Unknown error';
|
||||
const chunk = { error: { message: message } };
|
||||
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
|
||||
} else if (json.event_type === 'text-generation') {
|
||||
const text = json.text || '';
|
||||
const chunk = { choices: [{ text }] };
|
||||
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
} catch (e) {
|
||||
// ignore
|
||||
const stream = new CohereStream({ stream: jsonStream.body, eventShape: { type: 'json', messageTerminator: '\n' } });
|
||||
|
||||
for await (const json of stream.iterMessages()) {
|
||||
if (json.message) {
|
||||
const message = json.message || 'Unknown error';
|
||||
const chunk = { error: { message: message } };
|
||||
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
|
||||
} else if (json.event_type === 'text-generation') {
|
||||
const text = json.text || '';
|
||||
const chunk = { choices: [{ text }] };
|
||||
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
request.socket.on('close', function () {
|
||||
if (jsonStream.body instanceof Readable) jsonStream.body.destroy();
|
||||
response.end();
|
||||
});
|
||||
|
||||
jsonStream.body.on('end', () => {
|
||||
console.log('Streaming request finished');
|
||||
response.write('data: [DONE]\n\n');
|
||||
response.end();
|
||||
});
|
||||
console.log('Streaming request finished');
|
||||
response.write('data: [DONE]\n\n');
|
||||
response.end();
|
||||
} catch (error) {
|
||||
console.log('Error forwarding streaming response:', error);
|
||||
if (!response.headersSent) {
|
||||
@ -598,15 +587,15 @@ async function sendCohereRequest(request, response) {
|
||||
const apiUrl = API_COHERE + '/chat';
|
||||
|
||||
if (request.body.stream) {
|
||||
const stream = await fetch(apiUrl, config);
|
||||
const stream = await global.fetch(apiUrl, config);
|
||||
parseCohereStream(stream, request, response);
|
||||
} else {
|
||||
const generateResponse = await fetch(apiUrl, config);
|
||||
if (!generateResponse.ok) {
|
||||
console.log(`Cohere API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
|
||||
// a 401 unauthorized response breaks the frontend auth, so return a 500 instead. prob a better way of dealing with this.
|
||||
// 401s are already handled by the streaming processor and dont pop up an error toast, that should probably be fixed too.
|
||||
return response.status(generateResponse.status === 401 ? 500 : generateResponse.status).send({ error: true });
|
||||
const errorText = await generateResponse.text();
|
||||
console.log(`Cohere API returned error: ${generateResponse.status} ${generateResponse.statusText} ${errorText}`);
|
||||
const errorJson = tryParse(errorText) ?? { error: true };
|
||||
return response.status(generateResponse.status === 401 ? 500 : generateResponse.status).send(errorJson);
|
||||
}
|
||||
const generateResponseJson = await generateResponse.json();
|
||||
console.log('Cohere response:', generateResponseJson);
|
||||
|
Loading…
Reference in New Issue
Block a user