mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
@@ -1,126 +0,0 @@
|
||||
const DATA_PREFIX = 'data:';
|
||||
|
||||
/**
|
||||
* Borrowed from Cohere SDK (MIT License)
|
||||
* https://github.com/cohere-ai/cohere-typescript/blob/main/src/core/streaming-fetcher/Stream.ts
|
||||
* Copyright (c) 2021 Cohere
|
||||
*/
|
||||
class CohereStream {
|
||||
/** @type {ReadableStream} */
|
||||
stream;
|
||||
/** @type {string} */
|
||||
prefix;
|
||||
/** @type {string} */
|
||||
messageTerminator;
|
||||
/** @type {string|undefined} */
|
||||
streamTerminator;
|
||||
/** @type {AbortController} */
|
||||
controller = new AbortController();
|
||||
|
||||
constructor({ stream, eventShape }) {
|
||||
this.stream = stream;
|
||||
if (eventShape.type === 'sse') {
|
||||
this.prefix = DATA_PREFIX;
|
||||
this.messageTerminator = '\n';
|
||||
this.streamTerminator = eventShape.streamTerminator;
|
||||
} else {
|
||||
this.messageTerminator = eventShape.messageTerminator;
|
||||
}
|
||||
}
|
||||
|
||||
async *iterMessages() {
|
||||
const stream = readableStreamAsyncIterable(this.stream);
|
||||
let buf = '';
|
||||
let prefixSeen = false;
|
||||
let parsedAnyMessages = false;
|
||||
for await (const chunk of stream) {
|
||||
buf += this.decodeChunk(chunk);
|
||||
|
||||
let terminatorIndex;
|
||||
// Parse the chunk into as many messages as possible
|
||||
while ((terminatorIndex = buf.indexOf(this.messageTerminator)) >= 0) {
|
||||
// Extract the line from the buffer
|
||||
let line = buf.slice(0, terminatorIndex + 1);
|
||||
buf = buf.slice(terminatorIndex + 1);
|
||||
|
||||
// Skip empty lines
|
||||
if (line.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip the chunk until the prefix is found
|
||||
if (!prefixSeen && this.prefix != null) {
|
||||
const prefixIndex = line.indexOf(this.prefix);
|
||||
if (prefixIndex === -1) {
|
||||
continue;
|
||||
}
|
||||
prefixSeen = true;
|
||||
line = line.slice(prefixIndex + this.prefix.length);
|
||||
}
|
||||
|
||||
// If the stream terminator is present, return
|
||||
if (this.streamTerminator != null && line.includes(this.streamTerminator)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, yield message from the prefix to the terminator
|
||||
const message = JSON.parse(line);
|
||||
yield message;
|
||||
prefixSeen = false;
|
||||
parsedAnyMessages = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!parsedAnyMessages && buf.length > 0) {
|
||||
try {
|
||||
yield JSON.parse(buf);
|
||||
} catch (e) {
|
||||
console.error('Error parsing message:', e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for await (const message of this.iterMessages()) {
|
||||
yield message;
|
||||
}
|
||||
}
|
||||
|
||||
decodeChunk(chunk) {
|
||||
const decoder = new TextDecoder('utf8');
|
||||
return decoder.decode(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
function readableStreamAsyncIterable(stream) {
|
||||
if (stream[Symbol.asyncIterator]) {
|
||||
return stream;
|
||||
}
|
||||
|
||||
const reader = stream.getReader();
|
||||
return {
|
||||
async next() {
|
||||
try {
|
||||
const result = await reader.read();
|
||||
if (result?.done) {
|
||||
reader.releaseLock();
|
||||
} // release lock when stream becomes closed
|
||||
return result;
|
||||
} catch (e) {
|
||||
reader.releaseLock(); // release lock when stream becomes errored
|
||||
throw e;
|
||||
}
|
||||
},
|
||||
async return() {
|
||||
const cancelPromise = reader.cancel();
|
||||
reader.releaseLock();
|
||||
await cancelPromise;
|
||||
return { done: true, value: undefined };
|
||||
},
|
||||
[Symbol.asyncIterator]() {
|
||||
return this;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = CohereStream;
|
@@ -5,7 +5,6 @@ const { jsonParser } = require('../../express-common');
|
||||
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
|
||||
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
|
||||
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertAI21Messages, mergeMessages } = require('../../prompt-converters');
|
||||
const CohereStream = require('../../cohere-stream');
|
||||
|
||||
const { readSecret, SECRET_KEYS } = require('../secrets');
|
||||
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
|
||||
@@ -13,7 +12,8 @@ const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sente
|
||||
const API_OPENAI = 'https://api.openai.com/v1';
|
||||
const API_CLAUDE = 'https://api.anthropic.com/v1';
|
||||
const API_MISTRAL = 'https://api.mistral.ai/v1';
|
||||
const API_COHERE = 'https://api.cohere.ai/v1';
|
||||
const API_COHERE_V1 = 'https://api.cohere.ai/v1';
|
||||
const API_COHERE_V2 = 'https://api.cohere.ai/v2';
|
||||
const API_PERPLEXITY = 'https://api.perplexity.ai';
|
||||
const API_GROQ = 'https://api.groq.com/openai/v1';
|
||||
const API_MAKERSUITE = 'https://generativelanguage.googleapis.com';
|
||||
@@ -553,13 +553,14 @@ async function sendCohereRequest(request, response) {
|
||||
|
||||
try {
|
||||
const convertedHistory = convertCohereMessages(request.body.messages, request.body.char_name, request.body.user_name);
|
||||
const connectors = [];
|
||||
const tools = [];
|
||||
|
||||
const canDoWebSearch = !String(request.body.model).includes('c4ai-aya');
|
||||
if (request.body.websearch && canDoWebSearch) {
|
||||
connectors.push({
|
||||
id: 'web-search',
|
||||
if (Array.isArray(request.body.tools) && request.body.tools.length > 0) {
|
||||
tools.push(...request.body.tools);
|
||||
tools.forEach(tool => {
|
||||
if (tool?.function?.parameters?.$schema) {
|
||||
delete tool.function.parameters.$schema;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -567,9 +568,7 @@ async function sendCohereRequest(request, response) {
|
||||
const requestBody = {
|
||||
stream: Boolean(request.body.stream),
|
||||
model: request.body.model,
|
||||
message: convertedHistory.userPrompt,
|
||||
preamble: convertedHistory.systemPrompt,
|
||||
chat_history: convertedHistory.chatHistory,
|
||||
messages: convertedHistory.chatHistory,
|
||||
temperature: request.body.temperature,
|
||||
max_tokens: request.body.max_tokens,
|
||||
k: request.body.top_k,
|
||||
@@ -578,16 +577,13 @@ async function sendCohereRequest(request, response) {
|
||||
stop_sequences: request.body.stop,
|
||||
frequency_penalty: request.body.frequency_penalty,
|
||||
presence_penalty: request.body.presence_penalty,
|
||||
prompt_truncation: 'AUTO_PRESERVE_ORDER',
|
||||
connectors: connectors,
|
||||
documents: [],
|
||||
tools: tools,
|
||||
search_queries_only: false,
|
||||
};
|
||||
|
||||
const canDoSafetyMode = String(request.body.model).endsWith('08-2024');
|
||||
if (canDoSafetyMode) {
|
||||
requestBody.safety_mode = 'NONE';
|
||||
requestBody.safety_mode = 'OFF';
|
||||
}
|
||||
|
||||
console.log('Cohere request:', requestBody);
|
||||
@@ -603,11 +599,11 @@ async function sendCohereRequest(request, response) {
|
||||
timeout: 0,
|
||||
};
|
||||
|
||||
const apiUrl = API_COHERE + '/chat';
|
||||
const apiUrl = API_COHERE_V2 + '/chat';
|
||||
|
||||
if (request.body.stream) {
|
||||
const stream = await global.fetch(apiUrl, config);
|
||||
parseCohereStream(stream, request, response);
|
||||
const stream = await fetch(apiUrl, config);
|
||||
forwardFetchResponse(stream, response);
|
||||
} else {
|
||||
const generateResponse = await fetch(apiUrl, config);
|
||||
if (!generateResponse.ok) {
|
||||
@@ -658,7 +654,7 @@ router.post('/status', jsonParser, async function (request, response_getstatus_o
|
||||
headers = {};
|
||||
mergeObjectWithYaml(headers, request.body.custom_include_headers);
|
||||
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.COHERE) {
|
||||
api_url = API_COHERE;
|
||||
api_url = API_COHERE_V1;
|
||||
api_key_openai = readSecret(request.user.directories, SECRET_KEYS.COHERE);
|
||||
headers = {};
|
||||
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.ZEROONEAI) {
|
||||
|
@@ -277,38 +277,9 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, useTools,
|
||||
* @param {object[]} messages Array of messages
|
||||
* @param {string} charName Character name
|
||||
* @param {string} userName User name
|
||||
* @returns {{systemPrompt: string, chatHistory: object[], userPrompt: string}} Prompt for Cohere
|
||||
* @returns {{chatHistory: object[]}} Prompt for Cohere
|
||||
*/
|
||||
function convertCohereMessages(messages, charName = '', userName = '') {
|
||||
const roleMap = {
|
||||
'system': 'SYSTEM',
|
||||
'user': 'USER',
|
||||
'assistant': 'CHATBOT',
|
||||
};
|
||||
let systemPrompt = '';
|
||||
|
||||
// Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array.
|
||||
let i;
|
||||
for (i = 0; i < messages.length; i++) {
|
||||
if (messages[i].role !== 'system') {
|
||||
break;
|
||||
}
|
||||
// Append example names if not already done by the frontend (e.g. for group chats).
|
||||
if (userName && messages[i].name === 'example_user') {
|
||||
if (!messages[i].content.startsWith(`${userName}: `)) {
|
||||
messages[i].content = `${userName}: ${messages[i].content}`;
|
||||
}
|
||||
}
|
||||
if (charName && messages[i].name === 'example_assistant') {
|
||||
if (!messages[i].content.startsWith(`${charName}: `)) {
|
||||
messages[i].content = `${charName}: ${messages[i].content}`;
|
||||
}
|
||||
}
|
||||
systemPrompt += `${messages[i].content}\n\n`;
|
||||
}
|
||||
|
||||
messages.splice(0, i);
|
||||
|
||||
if (messages.length === 0) {
|
||||
messages.unshift({
|
||||
role: 'user',
|
||||
@@ -316,17 +287,45 @@ function convertCohereMessages(messages, charName = '', userName = '') {
|
||||
});
|
||||
}
|
||||
|
||||
const lastNonSystemMessageIndex = messages.findLastIndex(msg => msg.role === 'user' || msg.role === 'assistant');
|
||||
const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || PROMPT_PLACEHOLDER;
|
||||
|
||||
const chatHistory = messages.slice(0, lastNonSystemMessageIndex).map(msg => {
|
||||
return {
|
||||
role: roleMap[msg.role] || 'USER',
|
||||
message: msg.content,
|
||||
};
|
||||
messages.forEach((msg, index) => {
|
||||
// Tool calls require an assistent primer
|
||||
if (Array.isArray(msg.tool_calls)) {
|
||||
if (index > 0 && messages[index - 1].role === 'assistant') {
|
||||
msg.content = messages[index - 1].content;
|
||||
messages.splice(index - 1, 1);
|
||||
} else {
|
||||
msg.content = `I'm going to call the tool for that: ${msg.tool_calls.map(tc => tc?.function?.name).join(', ')}`;
|
||||
}
|
||||
}
|
||||
// No names support (who would've thought)
|
||||
if (msg.name) {
|
||||
if (msg.role == 'system' && msg.name == 'example_assistant') {
|
||||
if (charName && !msg.content.startsWith(`${charName}: `)) {
|
||||
msg.content = `${charName}: ${msg.content}`;
|
||||
}
|
||||
}
|
||||
if (msg.role == 'system' && msg.name == 'example_user') {
|
||||
if (userName && !msg.content.startsWith(`${userName}: `)) {
|
||||
msg.content = `${userName}: ${msg.content}`;
|
||||
}
|
||||
}
|
||||
if (msg.role !== 'system' && !msg.content.startsWith(`${msg.name}: `)) {
|
||||
msg.content = `${msg.name}: ${msg.content}`;
|
||||
}
|
||||
delete msg.name;
|
||||
}
|
||||
});
|
||||
|
||||
return { systemPrompt: systemPrompt.trim(), chatHistory, userPrompt };
|
||||
// A prompt should end with a user/tool message
|
||||
if (!['user', 'tool'].includes(messages[messages.length - 1].role)) {
|
||||
const userPlaceholder = getConfigValue('cohere.userPlaceholder', PROMPT_PLACEHOLDER || 'Continue');
|
||||
messages.push({
|
||||
role: 'user',
|
||||
content: userPlaceholder,
|
||||
});
|
||||
}
|
||||
|
||||
return { chatHistory: messages };
|
||||
}
|
||||
|
||||
/**
|
||||
|
Reference in New Issue
Block a user