Add ability to temporarily override response length for /gen and /genraw

This commit is contained in:
Cohee 2024-03-31 21:02:38 +03:00
parent ad4269f476
commit 3331cb6491
2 changed files with 138 additions and 75 deletions

View File

@ -2372,21 +2372,31 @@ function getStoppingStrings(isImpersonate, isContinue) {
* @param {boolean} skipWIAN whether to skip addition of World Info and Author's Note into the prompt
* @param {string} quietImage Image to use for the quiet prompt
* @param {string} quietName Name to use for the quiet prompt (defaults to "System:")
* @param {number} [responseLength] Maximum response length. If unset, the global default value is used.
* @returns
*/
export async function generateQuietPrompt(quiet_prompt, quietToLoud, skipWIAN, quietImage = null, quietName = null) {
export async function generateQuietPrompt(quiet_prompt, quietToLoud, skipWIAN, quietImage = null, quietName = null, responseLength = null) {
console.log('got into genQuietPrompt');
/** @type {GenerateOptions} */
const options = {
quiet_prompt,
quietToLoud,
skipWIAN: skipWIAN,
force_name2: true,
quietImage: quietImage,
quietName: quietName,
};
const generateFinished = await Generate('quiet', options);
return generateFinished;
const responseLengthCustomized = typeof responseLength === 'number' && responseLength > 0;
let originalResponseLength = -1;
try {
/** @type {GenerateOptions} */
const options = {
quiet_prompt,
quietToLoud,
skipWIAN: skipWIAN,
force_name2: true,
quietImage: quietImage,
quietName: quietName,
};
originalResponseLength = responseLengthCustomized ? saveResponseLength(main_api, responseLength) : -1;
const generateFinished = await Generate('quiet', options);
return generateFinished;
} finally {
if (responseLengthCustomized) {
restoreResponseLength(main_api, originalResponseLength);
}
}
}
/**
@ -2912,14 +2922,17 @@ class StreamingProcessor {
* @param {boolean} instructOverride true to override instruct mode, false to use the default value
* @param {boolean} quietToLoud true to generate a message in system mode, false to generate a message in character mode
* @param {string} [systemPrompt] System prompt to use. Only Instruct mode or OpenAI.
* @param {number} [responseLength] Maximum response length. If unset, the global default value is used.
* @returns {Promise<string>} Generated message
*/
export async function generateRaw(prompt, api, instructOverride, quietToLoud, systemPrompt) {
export async function generateRaw(prompt, api, instructOverride, quietToLoud, systemPrompt, responseLength) {
if (!api) {
api = main_api;
}
const abortController = new AbortController();
const responseLengthCustomized = typeof responseLength === 'number' && responseLength > 0;
let originalResponseLength = -1;
const isInstruct = power_user.instruct.enabled && api !== 'openai' && api !== 'novel' && !instructOverride;
const isQuiet = true;
@ -2934,70 +2947,109 @@ export async function generateRaw(prompt, api, instructOverride, quietToLoud, sy
prompt = isInstruct ? formatInstructModeChat(name1, prompt, false, true, '', name1, name2, false) : prompt;
prompt = isInstruct ? (prompt + formatInstructModePrompt(name2, false, '', name1, name2, isQuiet, quietToLoud)) : (prompt + '\n');
let generateData = {};
try {
originalResponseLength = responseLengthCustomized ? saveResponseLength(api, responseLength) : -1;
let generateData = {};
switch (api) {
case 'kobold':
case 'koboldhorde':
if (preset_settings === 'gui') {
generateData = { prompt: prompt, gui_settings: true, max_length: amount_gen, max_context_length: max_context, api_server };
} else {
const isHorde = api === 'koboldhorde';
const koboldSettings = koboldai_settings[koboldai_setting_names[preset_settings]];
generateData = getKoboldGenerationData(prompt, koboldSettings, amount_gen, max_context, isHorde, 'quiet');
switch (api) {
case 'kobold':
case 'koboldhorde':
if (preset_settings === 'gui') {
generateData = { prompt: prompt, gui_settings: true, max_length: amount_gen, max_context_length: max_context, api_server };
} else {
const isHorde = api === 'koboldhorde';
const koboldSettings = koboldai_settings[koboldai_setting_names[preset_settings]];
generateData = getKoboldGenerationData(prompt, koboldSettings, amount_gen, max_context, isHorde, 'quiet');
}
break;
case 'novel': {
const novelSettings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]];
generateData = getNovelGenerationData(prompt, novelSettings, amount_gen, false, false, null, 'quiet');
break;
}
break;
case 'novel': {
const novelSettings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]];
generateData = getNovelGenerationData(prompt, novelSettings, amount_gen, false, false, null, 'quiet');
break;
case 'textgenerationwebui':
generateData = getTextGenGenerationData(prompt, amount_gen, false, false, null, 'quiet');
break;
case 'openai': {
generateData = [{ role: 'user', content: prompt.trim() }];
if (systemPrompt) {
generateData.unshift({ role: 'system', content: systemPrompt.trim() });
}
} break;
}
case 'textgenerationwebui':
generateData = getTextGenGenerationData(prompt, amount_gen, false, false, null, 'quiet');
break;
case 'openai': {
generateData = [{ role: 'user', content: prompt.trim() }];
if (systemPrompt) {
generateData.unshift({ role: 'system', content: systemPrompt.trim() });
let data = {};
if (api == 'koboldhorde') {
data = await generateHorde(prompt, generateData, abortController.signal, false);
} else if (api == 'openai') {
data = await sendOpenAIRequest('quiet', generateData, abortController.signal);
} else {
const generateUrl = getGenerateUrl(api);
const response = await fetch(generateUrl, {
method: 'POST',
headers: getRequestHeaders(),
cache: 'no-cache',
body: JSON.stringify(generateData),
signal: abortController.signal,
});
if (!response.ok) {
const error = await response.json();
throw error;
}
} break;
data = await response.json();
}
if (data.error) {
throw new Error(data.error);
}
const message = cleanUpMessage(extractMessageFromData(data), false, false, true);
if (!message) {
throw new Error('No message generated');
}
return message;
} finally {
if (responseLengthCustomized) {
restoreResponseLength(api, originalResponseLength);
}
}
}
let data = {};
if (api == 'koboldhorde') {
data = await generateHorde(prompt, generateData, abortController.signal, false);
} else if (api == 'openai') {
data = await sendOpenAIRequest('quiet', generateData, abortController.signal);
/**
* Temporarily change the response length for the specified API.
* @param {string} api API to use.
* @param {number} responseLength Target response length.
* @returns {number} The original response length.
*/
function saveResponseLength(api, responseLength) {
let oldValue = -1;
if (api === 'openai') {
oldValue = oai_settings.openai_max_tokens;
oai_settings.openai_max_tokens = responseLength;
} else {
const generateUrl = getGenerateUrl(api);
const response = await fetch(generateUrl, {
method: 'POST',
headers: getRequestHeaders(),
cache: 'no-cache',
body: JSON.stringify(generateData),
signal: abortController.signal,
});
if (!response.ok) {
const error = await response.json();
throw error;
}
data = await response.json();
oldValue = max_context;
max_context = responseLength;
}
return oldValue;
}
if (data.error) {
throw new Error(data.error);
/**
* Restore the original response length for the specified API.
* @param {string} api API to use.
* @param {number} responseLength Target response length.
* @returns {void}
*/
function restoreResponseLength(api, responseLength) {
if (api === 'openai') {
oai_settings.openai_max_tokens = responseLength;
} else {
max_context = responseLength;
}
const message = cleanUpMessage(extractMessageFromData(data), false, false, true);
if (!message) {
throw new Error('No message generated');
}
return message;
}
/**
@ -4390,10 +4442,19 @@ export async function sendMessageAsUser(messageText, messageBias, insertAt = nul
}
}
export function getMaxContextSize() {
/**
* Gets the maximum usable context size for the current API.
* @param {number|null} overrideResponseLength Optional override for the response length.
* @returns {number} Maximum usable context size.
*/
export function getMaxContextSize(overrideResponseLength = null) {
if (typeof overrideResponseLength !== 'number' || overrideResponseLength <= 0 || isNaN(overrideResponseLength)) {
overrideResponseLength = null;
}
let this_max_context = 1487;
if (main_api == 'kobold' || main_api == 'koboldhorde' || main_api == 'textgenerationwebui') {
this_max_context = (max_context - amount_gen);
this_max_context = (max_context - (overrideResponseLength || amount_gen));
}
if (main_api == 'novel') {
this_max_context = Number(max_context);
@ -4410,10 +4471,10 @@ export function getMaxContextSize() {
}
}
this_max_context = this_max_context - amount_gen;
this_max_context = this_max_context - (overrideResponseLength || amount_gen);
}
if (main_api == 'openai') {
this_max_context = oai_settings.openai_max_context - oai_settings.openai_max_tokens;
this_max_context = oai_settings.openai_max_context - (overrideResponseLength || oai_settings.openai_max_tokens);
}
return this_max_context;
}

View File

@ -230,8 +230,8 @@ parser.addCommand('peek', peekCallback, [], '<span class="monospace">(message in
parser.addCommand('delswipe', deleteSwipeCallback, ['swipedel'], '<span class="monospace">(optional 1-based id)</span> deletes a swipe from the last chat message. If swipe id not provided - deletes the current swipe.', true, true);
parser.addCommand('echo', echoCallback, [], '<span class="monospace">(title=string severity=info/warning/error/success [text])</span> echoes the text to toast message. Useful for pipes debugging.', true, true);
//parser.addCommand('#', (_, value) => '', [], ' a comment, does nothing, e.g. <tt>/# the next three commands switch variables a and b</tt>', true, true);
parser.addCommand('gen', generateCallback, [], '<span class="monospace">(lock=on/off name="System" [prompt])</span> generates text using the provided prompt and passes it to the next command through the pipe, optionally locking user input while generating and allowing to configure the in-prompt name for instruct mode (default = "System"). "as" argument controls the role of the output prompt: system (default) or char.', true, true);
parser.addCommand('genraw', generateRawCallback, [], '<span class="monospace">(lock=on/off instruct=on/off stop=[] as=system/char system="system prompt" [prompt])</span> generates text using the provided prompt and passes it to the next command through the pipe, optionally locking user input while generating. Does not include chat history or character card. Use instruct=off to skip instruct formatting, e.g. <tt>/genraw instruct=off Why is the sky blue?</tt>. Use stop=... with a JSON-serialized array to add one-time custom stop strings, e.g. <tt>/genraw stop=["\\n"] Say hi</tt>. "as" argument controls the role of the output prompt: system (default) or char. "system" argument adds an (optional) system prompt at the start.', true, true);
parser.addCommand('gen', generateCallback, [], '<span class="monospace">(lock=on/off name="System" length=123 [prompt])</span> generates text using the provided prompt and passes it to the next command through the pipe, optionally locking user input while generating and allowing to configure the in-prompt name for instruct mode (default = "System"). "as" argument controls the role of the output prompt: system (default) or char. If "length" argument is provided as a number in tokens, allows to temporarily override an API response length.', true, true);
parser.addCommand('genraw', generateRawCallback, [], '<span class="monospace">(lock=on/off instruct=on/off stop=[] as=system/char system="system prompt" length=123 [prompt])</span> generates text using the provided prompt and passes it to the next command through the pipe, optionally locking user input while generating. Does not include chat history or character card. Use instruct=off to skip instruct formatting, e.g. <tt>/genraw instruct=off Why is the sky blue?</tt>. Use stop=... with a JSON-serialized array to add one-time custom stop strings, e.g. <tt>/genraw stop=["\\n"] Say hi</tt>. "as" argument controls the role of the output prompt: system (default) or char. "system" argument adds an (optional) system prompt at the start. If "length" argument is provided as a number in tokens, allows to temporarily override an API response length.', true, true);
parser.addCommand('addswipe', addSwipeCallback, ['swipeadd'], '<span class="monospace">(text)</span> adds a swipe to the last chat message.', true, true);
parser.addCommand('abort', abortCallback, [], ' aborts the slash command batch execution', true, true);
parser.addCommand('fuzzy', fuzzyCallback, [], 'list=["a","b","c"] threshold=0.4 (text to search) performs a fuzzy match of each items of list within the text to search. If any item matches then its name is returned. If no item list matches the text to search then no value is returned. The optional threshold (default is 0.4) allows some control over the matching. A low value (min 0.0) means the match is very strict. At 1.0 (max) the match is very loose and probably matches anything. The returned value passes to the next command through the pipe.', true, true); parser.addCommand('pass', (_, arg) => arg, ['return'], '<span class="monospace">(text)</span> passes the text to the next command through the pipe.', true, true);
@ -662,6 +662,7 @@ async function generateRawCallback(args, value) {
const as = args?.as || 'system';
const quietToLoud = as === 'char';
const systemPrompt = resolveVariable(args?.system) || '';
const length = Number(resolveVariable(args?.length) ?? 0) || 0;
try {
if (lock) {
@ -669,7 +670,7 @@ async function generateRawCallback(args, value) {
}
setEphemeralStopStrings(resolveVariable(args?.stop));
const result = await generateRaw(value, '', isFalseBoolean(args?.instruct), quietToLoud, systemPrompt);
const result = await generateRaw(value, '', isFalseBoolean(args?.instruct), quietToLoud, systemPrompt, length);
return result;
} finally {
if (lock) {
@ -690,6 +691,7 @@ async function generateCallback(args, value) {
const lock = isTrueBoolean(args?.lock);
const as = args?.as || 'system';
const quietToLoud = as === 'char';
const length = Number(resolveVariable(args?.length) ?? 0) || 0;
try {
if (lock) {
@ -698,7 +700,7 @@ async function generateCallback(args, value) {
setEphemeralStopStrings(resolveVariable(args?.stop));
const name = args?.name;
const result = await generateQuietPrompt(value, quietToLoud, false, '', name);
const result = await generateQuietPrompt(value, quietToLoud, false, '', name, length);
return result;
} finally {
if (lock) {