Stop string for user-continue. Trim spaces after name2

This commit is contained in:
Cohee 2023-11-22 16:16:48 +02:00
parent 55af72cb17
commit 61908935f5
5 changed files with 34 additions and 21 deletions

View File

@ -2156,13 +2156,23 @@ function diceRollReplace(input, invalidRollPlaceholder = '') {
}); });
} }
function getStoppingStrings(isImpersonate) { /**
* Gets stopping sequences for the prompt.
* @param {boolean} isImpersonate A request is made to impersonate a user
* @param {boolean} isContinue A request is made to continue the message
* @returns {string[]} Array of stopping strings
*/
function getStoppingStrings(isImpersonate, isContinue) {
const charString = `\n${name2}:`; const charString = `\n${name2}:`;
const userString = `\n${name1}:`; const userString = `\n${name1}:`;
const result = isImpersonate ? [charString] : [userString]; const result = isImpersonate ? [charString] : [userString];
result.push(userString); result.push(userString);
if (isContinue && Array.isArray(chat) && chat[chat.length - 1]?.is_user) {
result.push(charString);
}
// Add other group members as the stopping strings // Add other group members as the stopping strings
if (selected_group) { if (selected_group) {
const group = groups.find(x => x.id === selected_group); const group = groups.find(x => x.id === selected_group);
@ -2717,10 +2727,10 @@ export async function generateRaw(prompt, api) {
break; break;
case 'novel': case 'novel':
const novelSettings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]]; const novelSettings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]];
generateData = getNovelGenerationData(prompt, novelSettings, amount_gen, false, null); generateData = getNovelGenerationData(prompt, novelSettings, amount_gen, false, false, null);
break; break;
case 'textgenerationwebui': case 'textgenerationwebui':
generateData = getTextGenGenerationData(prompt, amount_gen, false, null); generateData = getTextGenGenerationData(prompt, amount_gen, false, false, null);
break; break;
case 'openai': case 'openai':
generateData = [{ role: 'user', content: prompt.trim() }]; generateData = [{ role: 'user', content: prompt.trim() }];
@ -3521,11 +3531,11 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
} }
} }
else if (main_api == 'textgenerationwebui') { else if (main_api == 'textgenerationwebui') {
generate_data = getTextGenGenerationData(finalPrompt, maxLength, isImpersonate, cfgValues); generate_data = getTextGenGenerationData(finalPrompt, maxLength, isImpersonate, isContinue, cfgValues);
} }
else if (main_api == 'novel') { else if (main_api == 'novel') {
const presetSettings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]]; const presetSettings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]];
generate_data = getNovelGenerationData(finalPrompt, presetSettings, maxLength, isImpersonate, cfgValues); generate_data = getNovelGenerationData(finalPrompt, presetSettings, maxLength, isImpersonate, isContinue, cfgValues);
} }
else if (main_api == 'openai') { else if (main_api == 'openai') {
let [prompt, counts] = await prepareOpenAIMessages({ let [prompt, counts] = await prepareOpenAIMessages({
@ -4328,7 +4338,7 @@ function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete
getMessage = substituteParams(power_user.user_prompt_bias) + getMessage; getMessage = substituteParams(power_user.user_prompt_bias) + getMessage;
} }
const stoppingStrings = getStoppingStrings(isImpersonate); const stoppingStrings = getStoppingStrings(isImpersonate, isContinue);
for (const stoppingString of stoppingStrings) { for (const stoppingString of stoppingStrings) {
if (stoppingString.length) { if (stoppingString.length) {
@ -4370,13 +4380,13 @@ function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete
} }
if (nameToTrim && getMessage.indexOf(`${nameToTrim}:`) == 0) { if (nameToTrim && getMessage.indexOf(`${nameToTrim}:`) == 0) {
getMessage = getMessage.substr(0, getMessage.indexOf(`${nameToTrim}:`)); getMessage = getMessage.substring(0, getMessage.indexOf(`${nameToTrim}:`));
} }
if (nameToTrim && getMessage.indexOf(`\n${nameToTrim}:`) >= 0) { if (nameToTrim && getMessage.indexOf(`\n${nameToTrim}:`) >= 0) {
getMessage = getMessage.substr(0, getMessage.indexOf(`\n${nameToTrim}:`)); getMessage = getMessage.substring(0, getMessage.indexOf(`\n${nameToTrim}:`));
} }
if (getMessage.indexOf('<|endoftext|>') != -1) { if (getMessage.indexOf('<|endoftext|>') != -1) {
getMessage = getMessage.substr(0, getMessage.indexOf('<|endoftext|>')); getMessage = getMessage.substring(0, getMessage.indexOf('<|endoftext|>'));
} }
const isInstruct = power_user.instruct.enabled && main_api !== 'openai'; const isInstruct = power_user.instruct.enabled && main_api !== 'openai';
if (isInstruct && power_user.instruct.stop_sequence) { if (isInstruct && power_user.instruct.stop_sequence) {
@ -4421,7 +4431,8 @@ function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete
} }
if (!power_user.allow_name2_display) { if (!power_user.allow_name2_display) {
getMessage = getMessage.replace(new RegExp(`(^|\n)${name2}:`, 'g'), "$1"); const name2Escaped = escapeRegex(name2);
getMessage = getMessage.replace(new RegExp(`(^|\n)${name2Escaped}:\\s*`, 'g'), "$1");
} }
if (isImpersonate) { if (isImpersonate) {

View File

@ -105,6 +105,7 @@ export function loadKoboldSettings(preset) {
*/ */
export function getKoboldGenerationData(finalPrompt, settings, maxLength, maxContextLength, isHorde, type) { export function getKoboldGenerationData(finalPrompt, settings, maxLength, maxContextLength, isHorde, type) {
const isImpersonate = type === 'impersonate'; const isImpersonate = type === 'impersonate';
const isContinue = type === 'continue';
const sampler_order = kai_settings.sampler_order || settings.sampler_order; const sampler_order = kai_settings.sampler_order || settings.sampler_order;
let generate_data = { let generate_data = {
@ -132,7 +133,7 @@ export function getKoboldGenerationData(finalPrompt, settings, maxLength, maxCon
s7: sampler_order[6], s7: sampler_order[6],
use_world_info: false, use_world_info: false,
singleline: false, singleline: false,
stop_sequence: (kai_flags.can_use_stop_sequence || isHorde) ? getStoppingStrings(isImpersonate) : undefined, stop_sequence: (kai_flags.can_use_stop_sequence || isHorde) ? getStoppingStrings(isImpersonate, isContinue) : undefined,
streaming: kai_settings.streaming_kobold && kai_flags.can_use_streaming && type !== 'quiet', streaming: kai_settings.streaming_kobold && kai_flags.can_use_streaming && type !== 'quiet',
can_abort: kai_flags.can_use_streaming, can_abort: kai_flags.can_use_streaming,
mirostat: (kai_flags.can_use_mirostat || isHorde) ? kai_settings.mirostat : undefined, mirostat: (kai_flags.can_use_mirostat || isHorde) ? kai_settings.mirostat : undefined,

View File

@ -409,7 +409,7 @@ function getBadWordPermutations(text) {
return result.filter(onlyUnique); return result.filter(onlyUnique);
} }
export function getNovelGenerationData(finalPrompt, this_settings, this_amount_gen, isImpersonate, cfgValues) { export function getNovelGenerationData(finalPrompt, settings, maxLength, isImpersonate, isContinue, cfgValues) {
if (cfgValues && cfgValues.guidanceScale && cfgValues.guidanceScale?.value !== 1) { if (cfgValues && cfgValues.guidanceScale && cfgValues.guidanceScale?.value !== 1) {
cfgValues.negativePrompt = (getCfgPrompt(cfgValues.guidanceScale, true))?.value; cfgValues.negativePrompt = (getCfgPrompt(cfgValues.guidanceScale, true))?.value;
} }
@ -419,7 +419,7 @@ export function getNovelGenerationData(finalPrompt, this_settings, this_amount_g
const tokenizerType = kayra ? tokenizers.NERD2 : (clio ? tokenizers.NERD : tokenizers.NONE); const tokenizerType = kayra ? tokenizers.NERD2 : (clio ? tokenizers.NERD : tokenizers.NONE);
const stopSequences = (tokenizerType !== tokenizers.NONE) const stopSequences = (tokenizerType !== tokenizers.NONE)
? getStoppingStrings(isImpersonate) ? getStoppingStrings(isImpersonate, isContinue)
.map(t => getTextTokens(tokenizerType, t)) .map(t => getTextTokens(tokenizerType, t))
: undefined; : undefined;
@ -440,7 +440,7 @@ export function getNovelGenerationData(finalPrompt, this_settings, this_amount_g
"model": nai_settings.model_novel, "model": nai_settings.model_novel,
"use_string": true, "use_string": true,
"temperature": Number(nai_settings.temperature), "temperature": Number(nai_settings.temperature),
"max_length": this_amount_gen < maximum_output_length ? this_amount_gen : maximum_output_length, "max_length": maxLength < maximum_output_length ? maxLength : maximum_output_length,
"min_length": Number(nai_settings.min_length), "min_length": Number(nai_settings.min_length),
"tail_free_sampling": Number(nai_settings.tail_free_sampling), "tail_free_sampling": Number(nai_settings.tail_free_sampling),
"repetition_penalty": Number(nai_settings.repetition_penalty), "repetition_penalty": Number(nai_settings.repetition_penalty),
@ -464,7 +464,7 @@ export function getNovelGenerationData(finalPrompt, this_settings, this_amount_g
"use_cache": false, "use_cache": false,
"return_full_text": false, "return_full_text": false,
"prefix": prefix, "prefix": prefix,
"order": nai_settings.order || this_settings.order || default_order, "order": nai_settings.order || settings.order || default_order,
}; };
} }

View File

@ -1440,6 +1440,7 @@ async function sendOpenAIRequest(type, messages, signal) {
const isTextCompletion = (isOAI && textCompletionModels.includes(oai_settings.openai_model)) || (isOpenRouter && oai_settings.openrouter_force_instruct && power_user.instruct.enabled); const isTextCompletion = (isOAI && textCompletionModels.includes(oai_settings.openai_model)) || (isOpenRouter && oai_settings.openrouter_force_instruct && power_user.instruct.enabled);
const isQuiet = type === 'quiet'; const isQuiet = type === 'quiet';
const isImpersonate = type === 'impersonate'; const isImpersonate = type === 'impersonate';
const isContinue = type === 'continue';
const stream = oai_settings.stream_openai && !isQuiet && !isScale && !isAI21 && !isPalm; const stream = oai_settings.stream_openai && !isQuiet && !isScale && !isAI21 && !isPalm;
if (isTextCompletion && isOpenRouter) { if (isTextCompletion && isOpenRouter) {
@ -1523,7 +1524,7 @@ async function sendOpenAIRequest(type, messages, signal) {
generate_data['use_fallback'] = oai_settings.openrouter_use_fallback; generate_data['use_fallback'] = oai_settings.openrouter_use_fallback;
if (isTextCompletion) { if (isTextCompletion) {
generate_data['stop'] = getStoppingStrings(isImpersonate); generate_data['stop'] = getStoppingStrings(isImpersonate, isContinue);
} }
} }

View File

@ -575,12 +575,12 @@ function getModel() {
return undefined; return undefined;
} }
export function getTextGenGenerationData(finalPrompt, this_amount_gen, isImpersonate, cfgValues) { export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate, isContinue, cfgValues) {
let APIflags = { let APIflags = {
'prompt': finalPrompt, 'prompt': finalPrompt,
'model': getModel(), 'model': getModel(),
'max_new_tokens': this_amount_gen, 'max_new_tokens': maxTokens,
'max_tokens': this_amount_gen, 'max_tokens': maxTokens,
'temperature': textgenerationwebui_settings.temp, 'temperature': textgenerationwebui_settings.temp,
'top_p': textgenerationwebui_settings.top_p, 'top_p': textgenerationwebui_settings.top_p,
'typical_p': textgenerationwebui_settings.typical_p, 'typical_p': textgenerationwebui_settings.typical_p,
@ -595,8 +595,8 @@ export function getTextGenGenerationData(finalPrompt, this_amount_gen, isImperso
'length_penalty': textgenerationwebui_settings.length_penalty, 'length_penalty': textgenerationwebui_settings.length_penalty,
'early_stopping': textgenerationwebui_settings.early_stopping, 'early_stopping': textgenerationwebui_settings.early_stopping,
'add_bos_token': textgenerationwebui_settings.add_bos_token, 'add_bos_token': textgenerationwebui_settings.add_bos_token,
'stopping_strings': getStoppingStrings(isImpersonate), 'stopping_strings': getStoppingStrings(isImpersonate, isContinue),
'stop': getStoppingStrings(isImpersonate), 'stop': getStoppingStrings(isImpersonate, isContinue),
'truncation_length': max_context, 'truncation_length': max_context,
'ban_eos_token': textgenerationwebui_settings.ban_eos_token, 'ban_eos_token': textgenerationwebui_settings.ban_eos_token,
'skip_special_tokens': textgenerationwebui_settings.skip_special_tokens, 'skip_special_tokens': textgenerationwebui_settings.skip_special_tokens,