From 42e6da4a361121f5c8c7ec1a0fafcb144c2f0ee4 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Sat, 26 Aug 2023 00:12:11 +0300 Subject: [PATCH] Add support for stop strings to OpenAI / Claude --- public/scripts/openai.js | 4 ++++ public/scripts/power-user.js | 17 ++++++++++++++--- server.js | 13 ++++++++++++- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 4b3af67fd..7351c94a0 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -34,6 +34,7 @@ import { } from "./PromptManager.js"; import { + getCustomStoppingStrings, persona_description_positions, power_user, } from "./power-user.js"; @@ -120,6 +121,7 @@ const j2_max_topk = 10.0; const j2_max_freq = 5.0; const j2_max_pres = 5.0; const openrouter_website_model = 'OR_Website'; +const openai_max_stop_strings = 4; let biasCache = undefined; let model_list = []; @@ -1138,6 +1140,7 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) { "max_tokens": oai_settings.openai_max_tokens, "stream": stream, "logit_bias": logit_bias, + "stop": getCustomStoppingStrings(openai_max_stop_strings), }; // Proxy is only supported for Claude and OpenAI @@ -1151,6 +1154,7 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) { generate_data['use_claude'] = true; generate_data['top_k'] = Number(oai_settings.top_k_openai); generate_data['exclude_assistant'] = oai_settings.exclude_assistant; + generate_data['stop'] = getCustomStoppingStrings(); // Claude shouldn't have limits on stop strings. 
// Don't add a prefill on quiet gens (summarization) if (!isQuiet && !oai_settings.exclude_assistant) { generate_data['assistant_prefill'] = substituteParams(oai_settings.assistant_prefill); diff --git a/public/scripts/power-user.js b/public/scripts/power-user.js index 9ab1d964d..3e560c2ba 100644 --- a/public/scripts/power-user.js +++ b/public/scripts/power-user.js @@ -1537,8 +1537,19 @@ function setAvgBG() { } -export function getCustomStoppingStrings() { + +/** + * Gets the custom stopping strings from the power user settings. + * @param {number | undefined} limit Number of strings to return. If undefined, returns all strings. + * @returns {string[]} An array of custom stopping strings + */ +export function getCustomStoppingStrings(limit = undefined) { try { + // If there's no custom stopping strings, return an empty array + if (!power_user.custom_stopping_strings) { + return []; + } + // Parse the JSON string const strings = JSON.parse(power_user.custom_stopping_strings); @@ -1547,8 +1558,8 @@ export function getCustomStoppingStrings() { return []; } - // Make sure all the elements are strings - return strings.filter((s) => typeof s === 'string'); + // Make sure all the elements are strings. Apply the limit. 
+ return strings.filter((s) => typeof s === 'string').slice(0, limit); } catch (error) { // If there's an error, return an empty array console.warn('Error parsing custom stopping strings:', error); diff --git a/server.js b/server.js index a0f818baa..ba71100e0 100644 --- a/server.js +++ b/server.js @@ -3391,6 +3391,12 @@ async function sendClaudeRequest(request, response) { } console.log('Claude request:', requestPrompt); + const stop_sequences = ["\n\nHuman:", "\n\nSystem:", "\n\nAssistant:"]; + + // Add custom stop sequences + if (Array.isArray(request.body.stop)) { + stop_sequences.push(...request.body.stop); + } const generateResponse = await fetch(api_url + '/complete', { method: "POST", @@ -3399,7 +3405,7 @@ async function sendClaudeRequest(request, response) { prompt: requestPrompt, model: request.body.model, max_tokens_to_sample: request.body.max_tokens, - stop_sequences: ["\n\nHuman:", "\n\nSystem:", "\n\nAssistant:"], + stop_sequences: stop_sequences, temperature: request.body.temperature, top_p: request.body.top_p, top_k: request.body.top_k, @@ -3489,6 +3495,11 @@ app.post("/generate_openai", jsonParser, function (request, response_generate_op return response_generate_openai.status(401).send({ error: true }); } + // Add custom stop sequences + if (Array.isArray(request.body.stop)) { + bodyParams['stop'] = request.body.stop; + } + const isTextCompletion = Boolean(request.body.model && (request.body.model.startsWith('text-') || request.body.model.startsWith('code-'))); const textPrompt = isTextCompletion ? convertChatMLPrompt(request.body.messages) : ''; const endpointUrl = isTextCompletion ? `${api_url}/completions` : `${api_url}/chat/completions`;