Restore temp response length as early as possible

Fixes #3297
This commit is contained in:
Cohee
2025-01-13 21:20:26 +02:00
parent 4322197aba
commit 749c5df29a

View File

@ -2724,7 +2724,7 @@ export function getStoppingStrings(isImpersonate, isContinue) {
export async function generateQuietPrompt(quiet_prompt, quietToLoud, skipWIAN, quietImage = null, quietName = null, responseLength = null, force_chid = null) { export async function generateQuietPrompt(quiet_prompt, quietToLoud, skipWIAN, quietImage = null, quietName = null, responseLength = null, force_chid = null) {
console.log('got into genQuietPrompt'); console.log('got into genQuietPrompt');
const responseLengthCustomized = typeof responseLength === 'number' && responseLength > 0; const responseLengthCustomized = typeof responseLength === 'number' && responseLength > 0;
let originalResponseLength = -1; let eventHook = () => {};
try { try {
/** @type {GenerateOptions} */ /** @type {GenerateOptions} */
const options = { const options = {
@ -2736,11 +2736,15 @@ export async function generateQuietPrompt(quiet_prompt, quietToLoud, skipWIAN, q
quietName: quietName, quietName: quietName,
force_chid: force_chid, force_chid: force_chid,
}; };
originalResponseLength = responseLengthCustomized ? saveResponseLength(main_api, responseLength) : -1; if (responseLengthCustomized) {
TempResponseLength.save(main_api, responseLength);
eventHook = TempResponseLength.setupEventHook(main_api);
}
return await Generate('quiet', options); return await Generate('quiet', options);
} finally { } finally {
if (responseLengthCustomized) { if (responseLengthCustomized && TempResponseLength.isCustomized()) {
restoreResponseLength(main_api, originalResponseLength); TempResponseLength.restore(main_api);
TempResponseLength.removeEventHook(main_api, eventHook);
} }
} }
} }
@ -3384,9 +3388,9 @@ export async function generateRaw(prompt, api, instructOverride, quietToLoud, sy
const abortController = new AbortController(); const abortController = new AbortController();
const responseLengthCustomized = typeof responseLength === 'number' && responseLength > 0; const responseLengthCustomized = typeof responseLength === 'number' && responseLength > 0;
let originalResponseLength = -1;
const isInstruct = power_user.instruct.enabled && api !== 'openai' && api !== 'novel' && !instructOverride; const isInstruct = power_user.instruct.enabled && api !== 'openai' && api !== 'novel' && !instructOverride;
const isQuiet = true; const isQuiet = true;
let eventHook = () => {};
if (systemPrompt) { if (systemPrompt) {
systemPrompt = substituteParams(systemPrompt); systemPrompt = substituteParams(systemPrompt);
@ -3400,7 +3404,9 @@ export async function generateRaw(prompt, api, instructOverride, quietToLoud, sy
prompt = isInstruct ? (prompt + formatInstructModePrompt(name2, false, '', name1, name2, isQuiet, quietToLoud)) : (prompt + '\n'); prompt = isInstruct ? (prompt + formatInstructModePrompt(name2, false, '', name1, name2, isQuiet, quietToLoud)) : (prompt + '\n');
try { try {
originalResponseLength = responseLengthCustomized ? saveResponseLength(api, responseLength) : -1; if (responseLengthCustomized) {
TempResponseLength.save(api, responseLength);
}
let generateData = {}; let generateData = {};
switch (api) { switch (api) {
@ -3413,20 +3419,24 @@ export async function generateRaw(prompt, api, instructOverride, quietToLoud, sy
const koboldSettings = koboldai_settings[koboldai_setting_names[preset_settings]]; const koboldSettings = koboldai_settings[koboldai_setting_names[preset_settings]];
generateData = getKoboldGenerationData(prompt, koboldSettings, amount_gen, max_context, isHorde, 'quiet'); generateData = getKoboldGenerationData(prompt, koboldSettings, amount_gen, max_context, isHorde, 'quiet');
} }
TempResponseLength.restore(api);
break; break;
case 'novel': { case 'novel': {
const novelSettings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]]; const novelSettings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]];
generateData = getNovelGenerationData(prompt, novelSettings, amount_gen, false, false, null, 'quiet'); generateData = getNovelGenerationData(prompt, novelSettings, amount_gen, false, false, null, 'quiet');
TempResponseLength.restore(api);
break; break;
} }
case 'textgenerationwebui': case 'textgenerationwebui':
generateData = getTextGenGenerationData(prompt, amount_gen, false, false, null, 'quiet'); generateData = getTextGenGenerationData(prompt, amount_gen, false, false, null, 'quiet');
TempResponseLength.restore(api);
break; break;
case 'openai': { case 'openai': {
generateData = [{ role: 'user', content: prompt.trim() }]; generateData = [{ role: 'user', content: prompt.trim() }];
if (systemPrompt) { if (systemPrompt) {
generateData.unshift({ role: 'system', content: systemPrompt.trim() }); generateData.unshift({ role: 'system', content: systemPrompt.trim() });
} }
eventHook = TempResponseLength.setupEventHook(api);
} break; } break;
} }
@ -3468,41 +3478,100 @@ export async function generateRaw(prompt, api, instructOverride, quietToLoud, sy
return message; return message;
} finally { } finally {
if (responseLengthCustomized) { if (responseLengthCustomized && TempResponseLength.isCustomized()) {
restoreResponseLength(api, originalResponseLength); TempResponseLength.restore(api);
TempResponseLength.removeEventHook(api, eventHook);
} }
} }
} }
/**
 * Tracks a temporary override of the response length setting so that the
 * previous value can be put back once the override is no longer needed.
 *
 * For the 'openai' API the overridden setting is `oai_settings.openai_max_tokens`;
 * for every other API it is the module-level `amount_gen`.
 *
 * The state is static (one override at a time for the whole module), so only a
 * single saved value exists at any moment.
 */
class TempResponseLength {
    /** Value that was in effect before the override; -1 means "nothing saved". */
    static #originalResponseLength = -1;
    /** API identifier the saved value belongs to; null when nothing is saved. */
    static #lastApi = null;

    /**
     * Picks the event on which the override is restored for the given API.
     * Kept in one place so that setting up and removing a hook always agree.
     * @param {string} api API identifier
     * @returns {string} Event type
     */
    static #eventTypeFor(api) {
        return api === 'openai'
            ? event_types.CHAT_COMPLETION_SETTINGS_READY
            : event_types.GENERATE_AFTER_DATA;
    }

    /**
     * Checks whether an original response length is currently saved.
     * @returns {boolean} True if a saved value greater than -1 is present
     */
    static isCustomized() {
        return this.#originalResponseLength > -1;
    }

    /**
     * Save the current response length for the specified API and apply a new one.
     * @param {string} api API identifier
     * @param {number} responseLength New response length
     */
    static save(api, responseLength) {
        // A second save while an override is still active would record the
        // temporary value as the "original" and lose the real setting for good.
        // Put the real value back first, then start a fresh override.
        if (this.#originalResponseLength !== -1) {
            console.warn('[TempResponseLength] Override already active. Restoring the original value before applying a new one.');
            this.restore(this.#lastApi);
        }

        if (api === 'openai') {
            this.#originalResponseLength = oai_settings.openai_max_tokens;
            oai_settings.openai_max_tokens = responseLength;
        } else {
            this.#originalResponseLength = amount_gen;
            amount_gen = responseLength;
        }

        this.#lastApi = api;
        console.log('[TempResponseLength] Saved original response length:', this.#originalResponseLength);
    }

    /**
     * Restore the original response length. Does nothing when no value is saved.
     * @param {string|null} api API identifier; only used when no API was recorded by save()
     * @returns {void}
     */
    static restore(api) {
        if (this.#originalResponseLength === -1) {
            return;
        }

        // The saved value belongs to the API it was taken from. Prefer that over
        // the argument so a mismatched caller cannot write it into the wrong setting.
        const targetApi = this.#lastApi ?? api;

        if (targetApi === 'openai') {
            oai_settings.openai_max_tokens = this.#originalResponseLength;
        } else {
            amount_gen = this.#originalResponseLength;
        }

        console.log('[TempResponseLength] Restored original response length:', this.#originalResponseLength);
        this.#originalResponseLength = -1;
        this.#lastApi = null;
    }

    /**
     * Sets up an event hook to restore the original response length when the event is emitted.
     * The hook is registered with `eventSource.once`.
     * @param {string} api API identifier
     * @returns {function(): void} Event hook function
     */
    static setupEventHook(api) {
        const eventHook = () => {
            if (this.isCustomized()) {
                this.restore(api);
            }
        };

        eventSource.once(this.#eventTypeFor(api), eventHook);
        return eventHook;
    }

    /**
     * Removes the event hook for the specified API.
     * @param {string} api API identifier
     * @param {function(): void} eventHook Previously set up event hook
     */
    static removeEventHook(api, eventHook) {
        eventSource.removeListener(this.#eventTypeFor(api), eventHook);
    }
}
@ -6871,6 +6940,11 @@ export async function saveSettings(type) {
return; return;
} }
if (TempResponseLength.isCustomized()) {
console.warn('Response length is currently being overridden. Restoring previous value before saving.');
TempResponseLength.restore(null);
}
//console.log('Entering settings with name1 = '+name1); //console.log('Entering settings with name1 = '+name1);
return jQuery.ajax({ return jQuery.ajax({
type: 'POST', type: 'POST',