diff --git a/default/config.yaml b/default/config.yaml index b3057451a..902f8a4d5 100644 --- a/default/config.yaml +++ b/default/config.yaml @@ -155,5 +155,9 @@ claude: # (e.g {{random}} macro or lorebooks not as in-chat injections). # Otherwise, you'll just waste money on cache misses. enableSystemPromptCache: false +# -- COHERE API CONFIGURATION -- +cohere: + # A placeholder prompt to be used when the message array doesn't end with a user message + userPlaceholder: "Continue" # -- SERVER PLUGIN CONFIGURATION -- enableServerPlugins: false diff --git a/public/index.html b/public/index.html index afed83cab..3abe9f3bb 100644 --- a/public/index.html +++ b/public/index.html @@ -595,17 +595,6 @@ -
- -
- - Allow the model to use the web-search connector. - -
-
Temperature diff --git a/public/locales/ar-sa.json b/public/locales/ar-sa.json index 90ef674cf..1716ffdef 100644 --- a/public/locales/ar-sa.json +++ b/public/locales/ar-sa.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "الحد الأقصى لطول الاستجابة (الرموز,الحرف)", "Multiple swipes per generation": "الضربات الشديدة المتعددة لكل جيل", "Enable OpenAI completion streaming": "تمكين بث الاكتمال من OpenAI", - "Enable Cohere web-search connector": "تمكين موصل بحث الويب Cohere", - "Web-search": "البحث في الويب", - "Allow the model to use the web-search connector.": "اسمح للنموذج باستخدام موصل بحث الويب.", "Frequency Penalty": "عقوبة التكرار", "Presence Penalty": "عقوبة الوجود", "Count Penalty": "عد ضربة جزاء", diff --git a/public/locales/de-de.json b/public/locales/de-de.json index ac4e63298..19d048378 100644 --- a/public/locales/de-de.json +++ b/public/locales/de-de.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "Maximale Antwortlänge (Tokens)", "Multiple swipes per generation": "Mehrere Swipes pro Generation", "Enable OpenAI completion streaming": "OpenAI-Vervollständigungsstreaming aktivieren", - "Enable Cohere web-search connector": "Cohere-Websuch-Connector aktivieren", - "Web-search": "Web-Suche", - "Allow the model to use the web-search connector.": "Erlauben Sie dem Modell, den Websuch-Connector zu verwenden.", "Frequency Penalty": "Frequenzstrafe", "Presence Penalty": "Präsenzstrafe", "Count Penalty": "Strafe zählen", diff --git a/public/locales/es-es.json b/public/locales/es-es.json index b7b746e5d..722595459 100644 --- a/public/locales/es-es.json +++ b/public/locales/es-es.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "Longitud máxima de respuesta (tokens)", "Multiple swipes per generation": "Múltiples golpes por generación", "Enable OpenAI completion streaming": "Activar streaming de completado de OpenAI", - "Enable Cohere web-search connector": "Habilitar el conector de búsqueda web de Cohere", - "Web-search": "Búsqueda Web", - "Allow the model to use the web-search connector.": "Permita que el modelo utilice el conector de búsqueda web.", "Frequency Penalty": "Penalización de frecuencia", "Presence Penalty": "Penalización de presencia", "Count Penalty": "Penalización de conteo", diff --git a/public/locales/fr-fr.json b/public/locales/fr-fr.json index e069895b6..551e484c4 100644 --- a/public/locales/fr-fr.json +++ b/public/locales/fr-fr.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "Longueur maximale de la réponse (tokens)", "Multiple swipes per generation": "Plusieurs balayages par génération", "Enable OpenAI completion streaming": "Activer le streaming de complétion OpenAI", - "Enable Cohere web-search connector": "Activer le connecteur de recherche Web Cohere", - "Web-search": "Recherche Internet", - "Allow the model to use the web-search connector.": "Autorisez le modèle à utiliser le connecteur de recherche Web.", "Frequency Penalty": "Pénalité de fréquence", "Presence Penalty": "Pénalité de présence", "Count Penalty": "Pénalité de décompte", diff --git a/public/locales/is-is.json b/public/locales/is-is.json index 634a7c2c2..36cb62965 100644 --- a/public/locales/is-is.json +++ b/public/locales/is-is.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "Hámarks lengd svörunar (í táknum)", "Multiple swipes per generation": "Mörg högg á hverja kynslóð", "Enable OpenAI completion streaming": "Virkja OpenAI klárastreymi", - "Enable Cohere web-search connector": "Virkja Cohere vefleitartengi", - "Web-search": "Vefleit", - "Allow the model to use the web-search connector.": "Leyfðu líkaninu að nota vefleitartengið.", "Frequency Penalty": "Tíðnarefning", "Presence Penalty": "Tilkoma refning", "Count Penalty": "Telja víti", diff --git a/public/locales/it-it.json b/public/locales/it-it.json index bab851988..3bdfef772 100644 --- a/public/locales/it-it.json +++ b/public/locales/it-it.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "Lunghezza massima della risposta (token)", "Multiple swipes per generation": "Più passaggi per generazione", "Enable OpenAI completion streaming": "Abilita lo streaming di completamento OpenAI", - "Enable Cohere web-search connector": "Abilita il connettore di ricerca web Cohere", - "Web-search": "Ricerca sul web", - "Allow the model to use the web-search connector.": "Consenti al modello di utilizzare il connettore di ricerca web.", "Frequency Penalty": "Penalità di frequenza", "Presence Penalty": "Penalità di presenza", "Count Penalty": "Conte Penalità", diff --git a/public/locales/ja-jp.json b/public/locales/ja-jp.json index 88453ad0c..f7652d2d8 100644 --- a/public/locales/ja-jp.json +++ b/public/locales/ja-jp.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "最大応答長(トークン数)", "Multiple swipes per generation": "世代ごとに複数のスワイプ", "Enable OpenAI completion streaming": "OpenAIの完了ストリーミングを有効にする", - "Enable Cohere web-search connector": "Cohereウェブ検索コネクタを有効にする", - "Web-search": "ウェブ検索", - "Allow the model to use the web-search connector.": "モデルが Web 検索コネクタを使用できるようにします。", "Frequency Penalty": "頻度ペナルティ", "Presence Penalty": "存在ペナルティ", "Count Penalty": "カウントペナルティ", diff --git a/public/locales/ko-kr.json b/public/locales/ko-kr.json index 919209f92..5c5a73f7c 100644 --- a/public/locales/ko-kr.json +++ b/public/locales/ko-kr.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "최대 응답 길이 (토큰)", "Multiple swipes per generation": "세대당 다중 스와이프", "Enable OpenAI completion streaming": "OpenAI 완성 스트리밍 활성화", - "Enable Cohere web-search connector": "Cohere 웹 검색 커넥터 활성화", - "Web-search": "웹 서핑", - "Allow the model to use the web-search connector.": "모델이 웹 검색 커넥터를 사용하도록 허용합니다.", "Frequency Penalty": "빈도 패널티", "Presence Penalty": "존재 패널티", "Count Penalty": "카운트 페널티", diff --git a/public/locales/nl-nl.json b/public/locales/nl-nl.json index a0e7e8f9b..e0a43e09f 100644 --- a/public/locales/nl-nl.json +++ b/public/locales/nl-nl.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "Maximale lengte van het antwoord (tokens)", "Multiple swipes per generation": "Meerdere swipes per generatie", "Enable OpenAI completion streaming": "OpenAI voltooiingsstreaming inschakelen", - "Enable Cohere web-search connector": "Schakel de Cohere-webzoekconnector in", - "Web-search": "Web-zoeken", - "Allow the model to use the web-search connector.": "Sta toe dat het model de webzoekconnector gebruikt.", "Frequency Penalty": "Frequentieboete", "Presence Penalty": "Aanwezigheidsboete", "Count Penalty": "Tel straf", diff --git a/public/locales/pt-pt.json b/public/locales/pt-pt.json index 1ecfd5222..6087e5db3 100644 --- a/public/locales/pt-pt.json +++ b/public/locales/pt-pt.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "Comprimento Máximo da Resposta (tokens)", "Multiple swipes per generation": "Vários furtos por geração", "Enable OpenAI completion streaming": "Ativar streaming de conclusão do OpenAI", - "Enable Cohere web-search connector": "Ativar o conector de pesquisa na Web Cohere", - "Web-search": "Pesquisa na internet", - "Allow the model to use the web-search connector.": "Permita que o modelo use o conector de pesquisa na Web.", "Frequency Penalty": "Pena de Frequência", "Presence Penalty": "Pena de Presença", "Count Penalty": "Contar penalidade", diff --git a/public/locales/ru-ru.json b/public/locales/ru-ru.json index a633db208..43ce1fe35 100644 --- a/public/locales/ru-ru.json +++ b/public/locales/ru-ru.json @@ -772,7 +772,6 @@ "Type a message, or /? for help": "Введите сообщение, или /? для получения справки", "Welcome to SillyTavern!": "Добро пожаловать в SillyTavern!", "Won't be shared with the character card on export.": "Не попадут в карточку персонажа при экспорте.", - "Web-search": "Веб-поиск", "Persona Name:": "Имя персоны:", "User first message": "Первое сообщение пользователя", "extension_token_counter": "Токенов:", @@ -1200,8 +1199,6 @@ "Streaming_desc": "Выводить текст последовательно по мере его генерации.\rЕсли параметр выключен, ответы будут отображаться сразу целиком, и только после полного завершения генерации.", "Max prompt cost:": "Max prompt cost:", "TFS": "TFS", - "Enable Cohere web-search connector": "Enable Cohere web-search connector", - "Allow the model to use the web-search connector.": "Allow the model to use the web-search connector.", "Count Penalty": "Count Penalty", "Min P": "Min P", "NSFW": "NSFW", diff --git a/public/locales/uk-ua.json b/public/locales/uk-ua.json index 7fa6e2ea3..5152fa611 100644 --- a/public/locales/uk-ua.json +++ b/public/locales/uk-ua.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "Довжина відповіді (токени)", "Multiple swipes per generation": "Кілька свайпів за покоління", "Enable OpenAI completion streaming": "Увімкнути потокове завершення OpenAI", - "Enable Cohere web-search connector": "Увімкнути конектор веб-пошуку Cohere", - "Web-search": "Веб-пошук", - "Allow the model to use the web-search connector.": "Дозвольте моделі використовувати конектор веб-пошуку.", "Frequency Penalty": "Штраф за частоту", "Presence Penalty": "Штраф за наявність", "Count Penalty": "Рахувати пенальті", diff --git a/public/locales/vi-vn.json b/public/locales/vi-vn.json index 268fdce77..79e2594eb 100644 --- a/public/locales/vi-vn.json +++ b/public/locales/vi-vn.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "Độ dài phản hồi tối đa (token)", "Multiple swipes per generation": "Vuốt nhiều lần trong một lần tạo", "Enable OpenAI completion streaming": "Bật streaming của OpenAI", - "Enable Cohere web-search connector": "Bật web tìm kiếm của Cohere", - "Web-search": "Tìm kiếm trên web", - "Allow the model to use the web-search connector.": "Cho phép model sử dụng trình kết nối tìm kiếm trên web.", "Frequency Penalty": "Frequency Penalty", "Presence Penalty": "Presence Penalty", "Count Penalty": "Count Penalty", diff --git a/public/locales/zh-cn.json b/public/locales/zh-cn.json index df8b7f831..fc6fae29c 100644 --- a/public/locales/zh-cn.json +++ b/public/locales/zh-cn.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "最大回复长度(以词符数计)", "Multiple swipes per generation": "每次生成多个备选回复", "Enable OpenAI completion streaming": "启用OpenAI文本补全流式传输", - "Enable Cohere web-search connector": "启用 Cohere 网络搜索连接器", - "Web-search": "联网搜索", - "Allow the model to use the web-search connector.": "允许模型使用联网搜索。", "Frequency Penalty": "频率惩罚", "Presence Penalty": "存在惩罚", "Count Penalty": "计数惩罚", diff --git a/public/locales/zh-tw.json b/public/locales/zh-tw.json index 59270882f..6e853966a 100644 --- a/public/locales/zh-tw.json +++ b/public/locales/zh-tw.json @@ -56,9 +56,6 @@ "Max Response Length (tokens)": "最大回應長度(符記數)", "Multiple swipes per generation": "每次生成多次滑動", "Enable OpenAI completion streaming": "啟用 OpenAI 補充串流", - "Enable Cohere web-search connector": "啟用 Cohere 網頁搜尋連接器", - "Web-search": "網頁搜尋", - "Allow the model to use the web-search connector.": "允許模型使用網頁搜尋連接器", "Frequency Penalty": "頻率懲罰", "Presence Penalty": "存在懲罰", "Count Penalty": "計數懲罰", diff --git a/public/script.js b/public/script.js index 8fc833666..cbc21dbfd 100644 --- a/public/script.js +++ b/public/script.js @@ -5436,7 +5436,7 @@ function extractMessageFromData(data) { case 'novel': return data.output; case 'openai': - return data?.choices?.[0]?.message?.content ?? data?.choices?.[0]?.text ?? data?.text ?? ''; + return data?.choices?.[0]?.message?.content ?? data?.choices?.[0]?.text ?? data?.text ?? data?.message?.tool_plan ?? data?.message?.content?.[0]?.text ?? ''; default: return ''; } diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 90334f299..94a537c07 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -225,7 +225,6 @@ const default_settings = { top_a_openai: 0, repetition_penalty_openai: 1, stream_openai: false, - websearch_cohere: false, openai_max_context: max_4k, openai_max_tokens: 300, wrap_in_quotes: false, @@ -302,7 +301,6 @@ const oai_settings = { top_a_openai: 0, repetition_penalty_openai: 1, stream_openai: false, - websearch_cohere: false, openai_max_context: max_4k, openai_max_tokens: 300, wrap_in_quotes: false, @@ -1847,7 +1845,6 @@ async function sendOpenAIRequest(type, messages, signal) { generate_data['frequency_penalty'] = Math.min(Math.max(Number(oai_settings.freq_pen_openai), 0), 1); generate_data['presence_penalty'] = Math.min(Math.max(Number(oai_settings.pres_pen_openai), 0), 1); generate_data['stop'] = getCustomStoppingStrings(5); - generate_data['websearch'] = oai_settings.websearch_cohere; } if (isPerplexity) { @@ -1980,8 +1977,10 @@ function getStreamingReply(data) { return data?.delta?.text || ''; } else if (oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) { return data?.candidates?.[0]?.content?.parts?.[0]?.text || ''; + } else if (oai_settings.chat_completion_source == chat_completion_sources.COHERE) { + return data?.delta?.message?.content?.text || data?.delta?.message?.tool_plan || ''; } else { - return data.choices[0]?.delta?.content ?? data.choices[0]?.message?.content ?? data.choices[0]?.text ?? ''; + return data.choices?.[0]?.delta?.content ?? data.choices?.[0]?.message?.content ?? data.choices?.[0]?.text ?? ''; } } @@ -2857,7 +2856,6 @@ function loadOpenAISettings(data, settings) { oai_settings.min_p_openai = settings.min_p_openai ?? default_settings.min_p_openai; oai_settings.repetition_penalty_openai = settings.repetition_penalty_openai ?? default_settings.repetition_penalty_openai; oai_settings.stream_openai = settings.stream_openai ?? default_settings.stream_openai; - oai_settings.websearch_cohere = settings.websearch_cohere ?? default_settings.websearch_cohere; oai_settings.openai_max_context = settings.openai_max_context ?? default_settings.openai_max_context; oai_settings.openai_max_tokens = settings.openai_max_tokens ?? default_settings.openai_max_tokens; oai_settings.bias_preset_selected = settings.bias_preset_selected ?? default_settings.bias_preset_selected; @@ -2931,7 +2929,6 @@ function loadOpenAISettings(data, settings) { if (settings.use_makersuite_sysprompt !== undefined) oai_settings.use_makersuite_sysprompt = !!settings.use_makersuite_sysprompt; if (settings.use_alt_scale !== undefined) { oai_settings.use_alt_scale = !!settings.use_alt_scale; updateScaleForm(); } $('#stream_toggle').prop('checked', oai_settings.stream_openai); - $('#websearch_toggle').prop('checked', oai_settings.websearch_cohere); $('#api_url_scale').val(oai_settings.api_url_scale); $('#openai_proxy_password').val(oai_settings.proxy_password); $('#claude_assistant_prefill').val(oai_settings.assistant_prefill); @@ -3258,7 +3255,6 @@ async function saveOpenAIPreset(name, settings, triggerUi = true) { personality_format: settings.personality_format, group_nudge_prompt: settings.group_nudge_prompt, stream_openai: settings.stream_openai, - websearch_cohere: settings.websearch_cohere, prompts: settings.prompts, prompt_order: settings.prompt_order, api_url_scale: settings.api_url_scale, @@ -3682,7 +3678,6 @@ function onSettingsPresetChange() { personality_format: ['#personality_format_textarea', 'personality_format', false], group_nudge_prompt: ['#group_nudge_prompt_textarea', 'group_nudge_prompt', false], stream_openai: ['#stream_toggle', 'stream_openai', true], - websearch_cohere: ['#websearch_toggle', 'websearch_cohere', true], prompts: ['', 'prompts', false], prompt_order: ['', 'prompt_order', false], api_url_scale: ['#api_url_scale', 'api_url_scale', false], @@ -4846,11 +4841,6 @@ export function initOpenAI() { saveSettingsDebounced(); }); - $('#websearch_toggle').on('change', function () { - oai_settings.websearch_cohere = !!$('#websearch_toggle').prop('checked'); - saveSettingsDebounced(); - }); - $('#wrap_in_quotes').on('change', function () { oai_settings.wrap_in_quotes = !!$('#wrap_in_quotes').prop('checked'); saveSettingsDebounced(); diff --git a/public/scripts/sse-stream.js b/public/scripts/sse-stream.js index 620caf3c5..dbe481e78 100644 --- a/public/scripts/sse-stream.js +++ b/public/scripts/sse-stream.js @@ -108,9 +108,21 @@ function getDelay(s) { * @returns {AsyncGenerator<{data: object, chunk: string}>} The parsed data and the chunk to be sent. */ async function* parseStreamData(json) { + // Cohere + if (typeof json.delta.message === 'object' && ['tool-plan-delta', 'content-delta'].includes(json.type)) { + const text = json?.delta?.message?.content?.text ?? ''; + for (let i = 0; i < text.length; i++) { + const str = json.delta.message.content.text[i]; + yield { + data: { ...json, delta: { message: { content: { text: str } } } }, + chunk: str, + }; + } + return; + } // Claude - if (typeof json.delta === 'object') { - if (typeof json.delta.text === 'string' && json.delta.text.length > 0) { + if (typeof json.delta === 'object' && typeof json.delta.text === 'string') { + if (json.delta.text.length > 0) { for (let i = 0; i < json.delta.text.length; i++) { const str = json.delta.text[i]; yield { diff --git a/public/scripts/tool-calling.js b/public/scripts/tool-calling.js index b56ab2db6..06c2b97a9 100644 --- a/public/scripts/tool-calling.js +++ b/public/scripts/tool-calling.js @@ -381,6 +381,22 @@ export class ToolManager { } } } + const cohereToolEvents = ['message-start', 'tool-call-start', 'tool-call-delta', 'tool-call-end']; + if (cohereToolEvents.includes(parsed?.type) && typeof parsed?.delta?.message === 'object') { + const choiceIndex = 0; + const toolCallIndex = parsed?.index ?? 0; + + if (!Array.isArray(toolCalls[choiceIndex])) { + toolCalls[choiceIndex] = []; + } + + if (toolCalls[choiceIndex][toolCallIndex] === undefined) { + toolCalls[choiceIndex][toolCallIndex] = {}; + } + + const targetToolCall = toolCalls[choiceIndex][toolCallIndex]; + ToolManager.#applyToolCallDelta(targetToolCall, parsed.delta.message); + } if (typeof parsed?.content_block === 'object') { const choiceIndex = 0; const toolCallIndex = parsed?.index ?? 0; @@ -483,6 +499,7 @@ export class ToolManager { chat_completion_sources.CLAUDE, chat_completion_sources.OPENROUTER, chat_completion_sources.GROQ, + chat_completion_sources.COHERE, ]; return supportedSources.includes(oai_settings.chat_completion_source); } @@ -509,7 +526,15 @@ export class ToolManager { // Parsed tool calls from streaming data if (Array.isArray(data) && data.length > 0 && Array.isArray(data[0])) { - return isClaudeToolCall(data[0]) ? data[0].filter(x => x).map(convertClaudeToolCall) : data[0]; + if (isClaudeToolCall(data[0])) { + return data[0].filter(x => x).map(convertClaudeToolCall); + } + + if (typeof data[0]?.[0]?.tool_calls === 'object') { + return Array.isArray(data[0]?.[0]?.tool_calls) ? data[0][0].tool_calls : [data[0][0].tool_calls]; + } + + return data[0]; } // Parsed tool calls from non-streaming data @@ -530,6 +555,11 @@ export class ToolManager { return content; } } + + // Cohere tool calls + if (typeof data?.message?.tool_calls === 'object') { + return Array.isArray(data?.message?.tool_calls) ? data.message.tool_calls : [data.message.tool_calls]; + } } /** diff --git a/src/cohere-stream.js b/src/cohere-stream.js deleted file mode 100644 index d59ecad5e..000000000 --- a/src/cohere-stream.js +++ /dev/null @@ -1,126 +0,0 @@ -const DATA_PREFIX = 'data:'; - -/** - * Borrowed from Cohere SDK (MIT License) - * https://github.com/cohere-ai/cohere-typescript/blob/main/src/core/streaming-fetcher/Stream.ts - * Copyright (c) 2021 Cohere - */ -class CohereStream { - /** @type {ReadableStream} */ - stream; - /** @type {string} */ - prefix; - /** @type {string} */ - messageTerminator; - /** @type {string|undefined} */ - streamTerminator; - /** @type {AbortController} */ - controller = new AbortController(); - - constructor({ stream, eventShape }) { - this.stream = stream; - if (eventShape.type === 'sse') { - this.prefix = DATA_PREFIX; - this.messageTerminator = '\n'; - this.streamTerminator = eventShape.streamTerminator; - } else { - this.messageTerminator = eventShape.messageTerminator; - } - } - - async *iterMessages() { - const stream = readableStreamAsyncIterable(this.stream); - let buf = ''; - let prefixSeen = false; - let parsedAnyMessages = false; - for await (const chunk of stream) { - buf += this.decodeChunk(chunk); - - let terminatorIndex; - // Parse the chunk into as many messages as possible - while ((terminatorIndex = buf.indexOf(this.messageTerminator)) >= 0) { - // Extract the line from the buffer - let line = buf.slice(0, terminatorIndex + 1); - buf = buf.slice(terminatorIndex + 1); - - // Skip empty lines - if (line.length === 0) { - continue; - } - - // Skip the chunk until the prefix is found - if (!prefixSeen && this.prefix != null) { - const prefixIndex = line.indexOf(this.prefix); - if (prefixIndex === -1) { - continue; - } - prefixSeen = true; - line = line.slice(prefixIndex + this.prefix.length); - } - - // If the stream terminator is present, return - if (this.streamTerminator != null && line.includes(this.streamTerminator)) { - return; - } - - // Otherwise, yield message from the prefix to the terminator - const message = JSON.parse(line); - yield message; - prefixSeen = false; - parsedAnyMessages = true; - } - } - - if (!parsedAnyMessages && buf.length > 0) { - try { - yield JSON.parse(buf); - } catch (e) { - console.error('Error parsing message:', e); - } - } - } - - async *[Symbol.asyncIterator]() { - for await (const message of this.iterMessages()) { - yield message; - } - } - - decodeChunk(chunk) { - const decoder = new TextDecoder('utf8'); - return decoder.decode(chunk); - } -} - -function readableStreamAsyncIterable(stream) { - if (stream[Symbol.asyncIterator]) { - return stream; - } - - const reader = stream.getReader(); - return { - async next() { - try { - const result = await reader.read(); - if (result?.done) { - reader.releaseLock(); - } // release lock when stream becomes closed - return result; - } catch (e) { - reader.releaseLock(); // release lock when stream becomes errored - throw e; - } - }, - async return() { - const cancelPromise = reader.cancel(); - reader.releaseLock(); - await cancelPromise; - return { done: true, value: undefined }; - }, - [Symbol.asyncIterator]() { - return this; - }, - }; -} - -module.exports = CohereStream; diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js index d2f963007..c190e6f71 100644 --- a/src/endpoints/backends/chat-completions.js +++ b/src/endpoints/backends/chat-completions.js @@ -5,7 +5,6 @@ const { jsonParser } = require('../../express-common'); const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants'); const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util'); const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertAI21Messages, mergeMessages } = require('../../prompt-converters'); -const CohereStream = require('../../cohere-stream'); const { readSecret, SECRET_KEYS } = require('../secrets'); const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers'); @@ -13,7 +12,8 @@ const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sente const API_OPENAI = 'https://api.openai.com/v1'; const API_CLAUDE = 'https://api.anthropic.com/v1'; const API_MISTRAL = 'https://api.mistral.ai/v1'; -const API_COHERE = 'https://api.cohere.ai/v1'; +const API_COHERE_V1 = 'https://api.cohere.ai/v1'; +const API_COHERE_V2 = 'https://api.cohere.ai/v2'; const API_PERPLEXITY = 'https://api.perplexity.ai'; const API_GROQ = 'https://api.groq.com/openai/v1'; const API_MAKERSUITE = 'https://generativelanguage.googleapis.com'; @@ -553,13 +553,14 @@ async function sendCohereRequest(request, response) { try { const convertedHistory = convertCohereMessages(request.body.messages, request.body.char_name, request.body.user_name); - const connectors = []; const tools = []; - const canDoWebSearch = !String(request.body.model).includes('c4ai-aya'); - if (request.body.websearch && canDoWebSearch) { - connectors.push({ - id: 'web-search', + if (Array.isArray(request.body.tools) && request.body.tools.length > 0) { + tools.push(...request.body.tools); + tools.forEach(tool => { + if (tool?.function?.parameters?.$schema) { + delete tool.function.parameters.$schema; + } }); } @@ -567,9 +568,7 @@ async function sendCohereRequest(request, response) { const requestBody = { stream: Boolean(request.body.stream), model: request.body.model, - message: convertedHistory.userPrompt, - preamble: convertedHistory.systemPrompt, - chat_history: convertedHistory.chatHistory, + messages: convertedHistory.chatHistory, temperature: request.body.temperature, max_tokens: request.body.max_tokens, k: request.body.top_k, @@ -578,16 +577,13 @@ async function sendCohereRequest(request, response) { stop_sequences: request.body.stop, frequency_penalty: request.body.frequency_penalty, presence_penalty: request.body.presence_penalty, - prompt_truncation: 'AUTO_PRESERVE_ORDER', - connectors: connectors, documents: [], tools: tools, - search_queries_only: false, }; const canDoSafetyMode = String(request.body.model).endsWith('08-2024'); if (canDoSafetyMode) { - requestBody.safety_mode = 'NONE'; + requestBody.safety_mode = 'OFF'; } console.log('Cohere request:', requestBody); @@ -603,11 +599,11 @@ async function sendCohereRequest(request, response) { timeout: 0, }; - const apiUrl = API_COHERE + '/chat'; + const apiUrl = API_COHERE_V2 + '/chat'; if (request.body.stream) { - const stream = await global.fetch(apiUrl, config); - parseCohereStream(stream, request, response); + const stream = await fetch(apiUrl, config); + forwardFetchResponse(stream, response); } else { const generateResponse = await fetch(apiUrl, config); if (!generateResponse.ok) { @@ -658,7 +654,7 @@ router.post('/status', jsonParser, async function (request, response_getstatus_o headers = {}; mergeObjectWithYaml(headers, request.body.custom_include_headers); } else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.COHERE) { - api_url = API_COHERE; + api_url = API_COHERE_V1; api_key_openai = readSecret(request.user.directories, SECRET_KEYS.COHERE); headers = {}; } else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.ZEROONEAI) { diff --git a/src/prompt-converters.js b/src/prompt-converters.js index 23dcd05c9..a0c90c148 100644 --- a/src/prompt-converters.js +++ b/src/prompt-converters.js @@ -277,38 +277,9 @@ function convertClaudeMessages(messages, prefillString, useSysPrompt, useTools, * @param {object[]} messages Array of messages * @param {string} charName Character name * @param {string} userName User name - * @returns {{systemPrompt: string, chatHistory: object[], userPrompt: string}} Prompt for Cohere + * @returns {{chatHistory: object[]}} Prompt for Cohere */ function convertCohereMessages(messages, charName = '', userName = '') { - const roleMap = { - 'system': 'SYSTEM', - 'user': 'USER', - 'assistant': 'CHATBOT', - }; - let systemPrompt = ''; - - // Collect all the system messages up until the first instance of a non-system message, and then remove them from the messages array. - let i; - for (i = 0; i < messages.length; i++) { - if (messages[i].role !== 'system') { - break; - } - // Append example names if not already done by the frontend (e.g. for group chats). - if (userName && messages[i].name === 'example_user') { - if (!messages[i].content.startsWith(`${userName}: `)) { - messages[i].content = `${userName}: ${messages[i].content}`; - } - } - if (charName && messages[i].name === 'example_assistant') { - if (!messages[i].content.startsWith(`${charName}: `)) { - messages[i].content = `${charName}: ${messages[i].content}`; - } - } - systemPrompt += `${messages[i].content}\n\n`; - } - - messages.splice(0, i); - if (messages.length === 0) { messages.unshift({ role: 'user', @@ -316,17 +287,45 @@ function convertCohereMessages(messages, charName = '', userName = '') { }); } - const lastNonSystemMessageIndex = messages.findLastIndex(msg => msg.role === 'user' || msg.role === 'assistant'); - const userPrompt = messages.slice(lastNonSystemMessageIndex).map(msg => msg.content).join('\n\n') || PROMPT_PLACEHOLDER; - - const chatHistory = messages.slice(0, lastNonSystemMessageIndex).map(msg => { - return { - role: roleMap[msg.role] || 'USER', - message: msg.content, - }; + messages.forEach((msg, index) => { + // Tool calls require an assistent primer + if (Array.isArray(msg.tool_calls)) { + if (index > 0 && messages[index - 1].role === 'assistant') { + msg.content = messages[index - 1].content; + messages.splice(index - 1, 1); + } else { + msg.content = `I'm going to call the tool for that: ${msg.tool_calls.map(tc => tc?.function?.name).join(', ')}`; + } + } + // No names support (who would've thought) + if (msg.name) { + if (msg.role == 'system' && msg.name == 'example_assistant') { + if (charName && !msg.content.startsWith(`${charName}: `)) { + msg.content = `${charName}: ${msg.content}`; + } + } + if (msg.role == 'system' && msg.name == 'example_user') { + if (userName && !msg.content.startsWith(`${userName}: `)) { + msg.content = `${userName}: ${msg.content}`; + } + } + if (msg.role !== 'system' && !msg.content.startsWith(`${msg.name}: `)) { + msg.content = `${msg.name}: ${msg.content}`; + } + delete msg.name; + } }); - return { systemPrompt: systemPrompt.trim(), chatHistory, userPrompt }; + // A prompt should end with a user/tool message + if (!['user', 'tool'].includes(messages[messages.length - 1].role)) { + const userPlaceholder = getConfigValue('cohere.userPlaceholder', PROMPT_PLACEHOLDER || 'Continue'); + messages.push({ + role: 'user', + content: userPlaceholder, + }); + } + + return { chatHistory: messages }; } /**