Cohere: new stream parser

This commit is contained in:
Cohee 2024-08-27 12:03:31 +03:00
parent 9286daaf3f
commit ebab976221
2 changed files with 148 additions and 33 deletions

126
src/cohere-stream.js Normal file
View File

@ -0,0 +1,126 @@
const DATA_PREFIX = 'data:';
/**
* Borrowed from Cohere SDK (MIT License)
* https://github.com/cohere-ai/cohere-typescript/blob/main/src/core/streaming-fetcher/Stream.ts
* Copyright (c) 2021 Cohere
*/
class CohereStream {
/** @type {ReadableStream} */
stream;
/** @type {string} */
prefix;
/** @type {string} */
messageTerminator;
/** @type {string|undefined} */
streamTerminator;
/** @type {AbortController} */
controller = new AbortController();
constructor({ stream, eventShape }) {
this.stream = stream;
if (eventShape.type === 'sse') {
this.prefix = DATA_PREFIX;
this.messageTerminator = '\n';
this.streamTerminator = eventShape.streamTerminator;
} else {
this.messageTerminator = eventShape.messageTerminator;
}
}
async *iterMessages() {
const stream = readableStreamAsyncIterable(this.stream);
let buf = '';
let prefixSeen = false;
let parsedAnyMessages = false;
for await (const chunk of stream) {
buf += this.decodeChunk(chunk);
let terminatorIndex;
// Parse the chunk into as many messages as possible
while ((terminatorIndex = buf.indexOf(this.messageTerminator)) >= 0) {
// Extract the line from the buffer
let line = buf.slice(0, terminatorIndex + 1);
buf = buf.slice(terminatorIndex + 1);
// Skip empty lines
if (line.length === 0) {
continue;
}
// Skip the chunk until the prefix is found
if (!prefixSeen && this.prefix != null) {
const prefixIndex = line.indexOf(this.prefix);
if (prefixIndex === -1) {
continue;
}
prefixSeen = true;
line = line.slice(prefixIndex + this.prefix.length);
}
// If the stream terminator is present, return
if (this.streamTerminator != null && line.includes(this.streamTerminator)) {
return;
}
// Otherwise, yield message from the prefix to the terminator
const message = JSON.parse(line);
yield message;
prefixSeen = false;
parsedAnyMessages = true;
}
}
if (!parsedAnyMessages && buf.length > 0) {
try {
yield JSON.parse(buf);
} catch (e) {
console.error('Error parsing message:', e);
}
}
}
async *[Symbol.asyncIterator]() {
for await (const message of this.iterMessages()) {
yield message;
}
}
decodeChunk(chunk) {
const decoder = new TextDecoder('utf8');
return decoder.decode(chunk);
}
}
function readableStreamAsyncIterable(stream) {
if (stream[Symbol.asyncIterator]) {
return stream;
}
const reader = stream.getReader();
return {
async next() {
try {
const result = await reader.read();
if (result?.done) {
reader.releaseLock();
} // release lock when stream becomes closed
return result;
} catch (e) {
reader.releaseLock(); // release lock when stream becomes errored
throw e;
}
},
async return() {
const cancelPromise = reader.cancel();
reader.releaseLock();
await cancelPromise;
return { done: true, value: undefined };
},
[Symbol.asyncIterator]() {
return this;
},
};
}
module.exports = CohereStream;

View File

@ -6,6 +6,7 @@ const { jsonParser } = require('../../express-common');
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools, convertAI21Messages } = require('../../prompt-converters');
const CohereStream = require('../../cohere-stream');
const { readSecret, SECRET_KEYS } = require('../secrets');
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
@ -41,16 +42,16 @@ function postProcessPrompt(messages, type, charName, userName) {
/**
* Ollama strikes back. Special boy #2's steaming routine.
* Wrap this abomination into proper SSE stream, again.
* @param {import('node-fetch').Response} jsonStream JSON stream
* @param {Response} jsonStream JSON stream
* @param {import('express').Request} request Express request
* @param {import('express').Response} response Express response
* @returns {Promise<any>} Nothing valuable
*/
async function parseCohereStream(jsonStream, request, response) {
try {
jsonStream.body.on('data', (data) => {
try {
const json = JSON.parse(data.toString());
const stream = new CohereStream({ stream: jsonStream.body, eventShape: { type: 'json', messageTerminator: '\n' } });
for await (const json of stream.iterMessages()) {
if (json.message) {
const message = json.message || 'Unknown error';
const chunk = { error: { message: message } };
@ -59,24 +60,12 @@ async function parseCohereStream(jsonStream, request, response) {
const text = json.text || '';
const chunk = { choices: [{ text }] };
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
} else {
return;
}
} catch (e) {
// ignore
}
});
request.socket.on('close', function () {
if (jsonStream.body instanceof Readable) jsonStream.body.destroy();
response.end();
});
jsonStream.body.on('end', () => {
console.log('Streaming request finished');
response.write('data: [DONE]\n\n');
response.end();
});
} catch (error) {
console.log('Error forwarding streaming response:', error);
if (!response.headersSent) {
@ -598,15 +587,15 @@ async function sendCohereRequest(request, response) {
const apiUrl = API_COHERE + '/chat';
if (request.body.stream) {
const stream = await fetch(apiUrl, config);
const stream = await global.fetch(apiUrl, config);
parseCohereStream(stream, request, response);
} else {
const generateResponse = await fetch(apiUrl, config);
if (!generateResponse.ok) {
console.log(`Cohere API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
// a 401 unauthorized response breaks the frontend auth, so return a 500 instead. prob a better way of dealing with this.
// 401s are already handled by the streaming processor and dont pop up an error toast, that should probably be fixed too.
return response.status(generateResponse.status === 401 ? 500 : generateResponse.status).send({ error: true });
const errorText = await generateResponse.text();
console.log(`Cohere API returned error: ${generateResponse.status} ${generateResponse.statusText} ${errorText}`);
const errorJson = tryParse(errorText) ?? { error: true };
return response.status(generateResponse.status === 401 ? 500 : generateResponse.status).send(errorJson);
}
const generateResponseJson = await generateResponse.json();
console.log('Cohere response:', generateResponseJson);