mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Cohere: new stream parser
This commit is contained in:
126
src/cohere-stream.js
Normal file
126
src/cohere-stream.js
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
const DATA_PREFIX = 'data:';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Borrowed from Cohere SDK (MIT License)
|
||||||
|
* https://github.com/cohere-ai/cohere-typescript/blob/main/src/core/streaming-fetcher/Stream.ts
|
||||||
|
* Copyright (c) 2021 Cohere
|
||||||
|
*/
|
||||||
|
class CohereStream {
|
||||||
|
/** @type {ReadableStream} */
|
||||||
|
stream;
|
||||||
|
/** @type {string} */
|
||||||
|
prefix;
|
||||||
|
/** @type {string} */
|
||||||
|
messageTerminator;
|
||||||
|
/** @type {string|undefined} */
|
||||||
|
streamTerminator;
|
||||||
|
/** @type {AbortController} */
|
||||||
|
controller = new AbortController();
|
||||||
|
|
||||||
|
constructor({ stream, eventShape }) {
|
||||||
|
this.stream = stream;
|
||||||
|
if (eventShape.type === 'sse') {
|
||||||
|
this.prefix = DATA_PREFIX;
|
||||||
|
this.messageTerminator = '\n';
|
||||||
|
this.streamTerminator = eventShape.streamTerminator;
|
||||||
|
} else {
|
||||||
|
this.messageTerminator = eventShape.messageTerminator;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async *iterMessages() {
|
||||||
|
const stream = readableStreamAsyncIterable(this.stream);
|
||||||
|
let buf = '';
|
||||||
|
let prefixSeen = false;
|
||||||
|
let parsedAnyMessages = false;
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
buf += this.decodeChunk(chunk);
|
||||||
|
|
||||||
|
let terminatorIndex;
|
||||||
|
// Parse the chunk into as many messages as possible
|
||||||
|
while ((terminatorIndex = buf.indexOf(this.messageTerminator)) >= 0) {
|
||||||
|
// Extract the line from the buffer
|
||||||
|
let line = buf.slice(0, terminatorIndex + 1);
|
||||||
|
buf = buf.slice(terminatorIndex + 1);
|
||||||
|
|
||||||
|
// Skip empty lines
|
||||||
|
if (line.length === 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip the chunk until the prefix is found
|
||||||
|
if (!prefixSeen && this.prefix != null) {
|
||||||
|
const prefixIndex = line.indexOf(this.prefix);
|
||||||
|
if (prefixIndex === -1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
prefixSeen = true;
|
||||||
|
line = line.slice(prefixIndex + this.prefix.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the stream terminator is present, return
|
||||||
|
if (this.streamTerminator != null && line.includes(this.streamTerminator)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, yield message from the prefix to the terminator
|
||||||
|
const message = JSON.parse(line);
|
||||||
|
yield message;
|
||||||
|
prefixSeen = false;
|
||||||
|
parsedAnyMessages = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!parsedAnyMessages && buf.length > 0) {
|
||||||
|
try {
|
||||||
|
yield JSON.parse(buf);
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Error parsing message:', e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for await (const message of this.iterMessages()) {
|
||||||
|
yield message;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
decodeChunk(chunk) {
|
||||||
|
const decoder = new TextDecoder('utf8');
|
||||||
|
return decoder.decode(chunk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function readableStreamAsyncIterable(stream) {
|
||||||
|
if (stream[Symbol.asyncIterator]) {
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = stream.getReader();
|
||||||
|
return {
|
||||||
|
async next() {
|
||||||
|
try {
|
||||||
|
const result = await reader.read();
|
||||||
|
if (result?.done) {
|
||||||
|
reader.releaseLock();
|
||||||
|
} // release lock when stream becomes closed
|
||||||
|
return result;
|
||||||
|
} catch (e) {
|
||||||
|
reader.releaseLock(); // release lock when stream becomes errored
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
async return() {
|
||||||
|
const cancelPromise = reader.cancel();
|
||||||
|
reader.releaseLock();
|
||||||
|
await cancelPromise;
|
||||||
|
return { done: true, value: undefined };
|
||||||
|
},
|
||||||
|
[Symbol.asyncIterator]() {
|
||||||
|
return this;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = CohereStream;
|
@@ -6,6 +6,7 @@ const { jsonParser } = require('../../express-common');
|
|||||||
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
|
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
|
||||||
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
|
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
|
||||||
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools, convertAI21Messages } = require('../../prompt-converters');
|
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools, convertAI21Messages } = require('../../prompt-converters');
|
||||||
|
const CohereStream = require('../../cohere-stream');
|
||||||
|
|
||||||
const { readSecret, SECRET_KEYS } = require('../secrets');
|
const { readSecret, SECRET_KEYS } = require('../secrets');
|
||||||
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
|
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
|
||||||
@@ -41,42 +42,30 @@ function postProcessPrompt(messages, type, charName, userName) {
|
|||||||
/**
|
/**
|
||||||
* Ollama strikes back. Special boy #2's steaming routine.
|
* Ollama strikes back. Special boy #2's steaming routine.
|
||||||
* Wrap this abomination into proper SSE stream, again.
|
* Wrap this abomination into proper SSE stream, again.
|
||||||
* @param {import('node-fetch').Response} jsonStream JSON stream
|
* @param {Response} jsonStream JSON stream
|
||||||
* @param {import('express').Request} request Express request
|
* @param {import('express').Request} request Express request
|
||||||
* @param {import('express').Response} response Express response
|
* @param {import('express').Response} response Express response
|
||||||
* @returns {Promise<any>} Nothing valuable
|
* @returns {Promise<any>} Nothing valuable
|
||||||
*/
|
*/
|
||||||
async function parseCohereStream(jsonStream, request, response) {
|
async function parseCohereStream(jsonStream, request, response) {
|
||||||
try {
|
try {
|
||||||
jsonStream.body.on('data', (data) => {
|
const stream = new CohereStream({ stream: jsonStream.body, eventShape: { type: 'json', messageTerminator: '\n' } });
|
||||||
try {
|
|
||||||
const json = JSON.parse(data.toString());
|
for await (const json of stream.iterMessages()) {
|
||||||
if (json.message) {
|
if (json.message) {
|
||||||
const message = json.message || 'Unknown error';
|
const message = json.message || 'Unknown error';
|
||||||
const chunk = { error: { message: message } };
|
const chunk = { error: { message: message } };
|
||||||
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
|
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
|
||||||
} else if (json.event_type === 'text-generation') {
|
} else if (json.event_type === 'text-generation') {
|
||||||
const text = json.text || '';
|
const text = json.text || '';
|
||||||
const chunk = { choices: [{ text }] };
|
const chunk = { choices: [{ text }] };
|
||||||
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
|
response.write(`data: ${JSON.stringify(chunk)}\n\n`);
|
||||||
} else {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
// ignore
|
|
||||||
}
|
}
|
||||||
});
|
}
|
||||||
|
|
||||||
request.socket.on('close', function () {
|
console.log('Streaming request finished');
|
||||||
if (jsonStream.body instanceof Readable) jsonStream.body.destroy();
|
response.write('data: [DONE]\n\n');
|
||||||
response.end();
|
response.end();
|
||||||
});
|
|
||||||
|
|
||||||
jsonStream.body.on('end', () => {
|
|
||||||
console.log('Streaming request finished');
|
|
||||||
response.write('data: [DONE]\n\n');
|
|
||||||
response.end();
|
|
||||||
});
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log('Error forwarding streaming response:', error);
|
console.log('Error forwarding streaming response:', error);
|
||||||
if (!response.headersSent) {
|
if (!response.headersSent) {
|
||||||
@@ -598,15 +587,15 @@ async function sendCohereRequest(request, response) {
|
|||||||
const apiUrl = API_COHERE + '/chat';
|
const apiUrl = API_COHERE + '/chat';
|
||||||
|
|
||||||
if (request.body.stream) {
|
if (request.body.stream) {
|
||||||
const stream = await fetch(apiUrl, config);
|
const stream = await global.fetch(apiUrl, config);
|
||||||
parseCohereStream(stream, request, response);
|
parseCohereStream(stream, request, response);
|
||||||
} else {
|
} else {
|
||||||
const generateResponse = await fetch(apiUrl, config);
|
const generateResponse = await fetch(apiUrl, config);
|
||||||
if (!generateResponse.ok) {
|
if (!generateResponse.ok) {
|
||||||
console.log(`Cohere API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
|
const errorText = await generateResponse.text();
|
||||||
// a 401 unauthorized response breaks the frontend auth, so return a 500 instead. prob a better way of dealing with this.
|
console.log(`Cohere API returned error: ${generateResponse.status} ${generateResponse.statusText} ${errorText}`);
|
||||||
// 401s are already handled by the streaming processor and dont pop up an error toast, that should probably be fixed too.
|
const errorJson = tryParse(errorText) ?? { error: true };
|
||||||
return response.status(generateResponse.status === 401 ? 500 : generateResponse.status).send({ error: true });
|
return response.status(generateResponse.status === 401 ? 500 : generateResponse.status).send(errorJson);
|
||||||
}
|
}
|
||||||
const generateResponseJson = await generateResponse.json();
|
const generateResponseJson = await generateResponse.json();
|
||||||
console.log('Cohere response:', generateResponseJson);
|
console.log('Cohere response:', generateResponseJson);
|
||||||
|
Reference in New Issue
Block a user