diff --git a/public/script.js b/public/script.js index e28f51624..f89d3c281 100644 --- a/public/script.js +++ b/public/script.js @@ -2724,7 +2724,7 @@ export function getStoppingStrings(isImpersonate, isContinue) { export async function generateQuietPrompt(quiet_prompt, quietToLoud, skipWIAN, quietImage = null, quietName = null, responseLength = null, force_chid = null) { console.log('got into genQuietPrompt'); const responseLengthCustomized = typeof responseLength === 'number' && responseLength > 0; - let originalResponseLength = -1; + let eventHook = () => {}; try { /** @type {GenerateOptions} */ const options = { @@ -2736,11 +2736,15 @@ export async function generateQuietPrompt(quiet_prompt, quietToLoud, skipWIAN, q quietName: quietName, force_chid: force_chid, }; - originalResponseLength = responseLengthCustomized ? saveResponseLength(main_api, responseLength) : -1; + if (responseLengthCustomized) { + TempResponseLength.save(main_api, responseLength); + eventHook = TempResponseLength.setupEventHook(main_api); + } return await Generate('quiet', options); } finally { - if (responseLengthCustomized) { - restoreResponseLength(main_api, originalResponseLength); + if (responseLengthCustomized && TempResponseLength.isCustomized()) { + TempResponseLength.restore(main_api); + TempResponseLength.removeEventHook(main_api, eventHook); } } } @@ -3384,9 +3388,9 @@ export async function generateRaw(prompt, api, instructOverride, quietToLoud, sy const abortController = new AbortController(); const responseLengthCustomized = typeof responseLength === 'number' && responseLength > 0; - let originalResponseLength = -1; const isInstruct = power_user.instruct.enabled && api !== 'openai' && api !== 'novel' && !instructOverride; const isQuiet = true; + let eventHook = () => {}; if (systemPrompt) { systemPrompt = substituteParams(systemPrompt); @@ -3400,7 +3404,9 @@ export async function generateRaw(prompt, api, instructOverride, quietToLoud, sy prompt = isInstruct ? (prompt + formatInstructModePrompt(name2, false, '', name1, name2, isQuiet, quietToLoud)) : (prompt + '\n'); try { - originalResponseLength = responseLengthCustomized ? saveResponseLength(api, responseLength) : -1; + if (responseLengthCustomized) { + TempResponseLength.save(api, responseLength); + } let generateData = {}; switch (api) { @@ -3413,20 +3419,24 @@ export async function generateRaw(prompt, api, instructOverride, quietToLoud, sy const koboldSettings = koboldai_settings[koboldai_setting_names[preset_settings]]; generateData = getKoboldGenerationData(prompt, koboldSettings, amount_gen, max_context, isHorde, 'quiet'); } + TempResponseLength.restore(api); break; case 'novel': { const novelSettings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]]; generateData = getNovelGenerationData(prompt, novelSettings, amount_gen, false, false, null, 'quiet'); + TempResponseLength.restore(api); break; } case 'textgenerationwebui': generateData = getTextGenGenerationData(prompt, amount_gen, false, false, null, 'quiet'); + TempResponseLength.restore(api); break; case 'openai': { generateData = [{ role: 'user', content: prompt.trim() }]; if (systemPrompt) { generateData.unshift({ role: 'system', content: systemPrompt.trim() }); } + eventHook = TempResponseLength.setupEventHook(api); } break; } @@ -3468,41 +3478,100 @@ export async function generateRaw(prompt, api, instructOverride, quietToLoud, sy return message; } finally { - if (responseLengthCustomized) { - restoreResponseLength(api, originalResponseLength); + if (responseLengthCustomized && TempResponseLength.isCustomized()) { + TempResponseLength.restore(api); + TempResponseLength.removeEventHook(api, eventHook); } } } -/** - * Temporarily change the response length for the specified API. - * @param {string} api API to use. - * @param {number} responseLength Target response length. - * @returns {number} The original response length. - */ -function saveResponseLength(api, responseLength) { - let oldValue = -1; - if (api === 'openai') { - oldValue = oai_settings.openai_max_tokens; - oai_settings.openai_max_tokens = responseLength; - } else { - oldValue = amount_gen; - amount_gen = responseLength; - } - return oldValue; -} +class TempResponseLength { + static #originalResponseLength = -1; + static #lastApi = null; -/** - * Restore the original response length for the specified API. - * @param {string} api API to use. - * @param {number} responseLength Target response length. - * @returns {void} - */ -function restoreResponseLength(api, responseLength) { - if (api === 'openai') { - oai_settings.openai_max_tokens = responseLength; - } else { - amount_gen = responseLength; + static isCustomized() { + return this.#originalResponseLength > -1; + } + + /** + * Save the current response length for the specified API. + * @param {string} api API identifier + * @param {number} responseLength New response length + */ + static save(api, responseLength) { + if (api === 'openai') { + this.#originalResponseLength = oai_settings.openai_max_tokens; + oai_settings.openai_max_tokens = responseLength; + } else { + this.#originalResponseLength = amount_gen; + amount_gen = responseLength; + } + + this.#lastApi = api; + console.log('[TempResponseLength] Saved original response length:', TempResponseLength.#originalResponseLength); + } + + /** + * Restore the original response length for the specified API. + * @param {string|null} api API identifier + * @returns {void} + */ + static restore(api) { + if (this.#originalResponseLength === -1) { + return; + } + if (!api && this.#lastApi) { + api = this.#lastApi; + } + if (api === 'openai') { + oai_settings.openai_max_tokens = this.#originalResponseLength; + } else { + amount_gen = this.#originalResponseLength; + } + + console.log('[TempResponseLength] Restored original response length:', this.#originalResponseLength); + this.#originalResponseLength = -1; + this.#lastApi = null; + } + + /** + * Sets up an event hook to restore the original response length when the event is emitted. + * @param {string} api API identifier + * @returns {function(): void} Event hook function + */ + static setupEventHook(api) { + const eventHook = () => { + if (this.isCustomized()) { + this.restore(api); + } + }; + + switch (api) { + case 'openai': + eventSource.once(event_types.CHAT_COMPLETION_SETTINGS_READY, eventHook); + break; + default: + eventSource.once(event_types.GENERATE_AFTER_DATA, eventHook); + break; + } + + return eventHook; + } + + /** + * Removes the event hook for the specified API. + * @param {string} api API identifier + * @param {function(): void} eventHook Previously set up event hook + */ + static removeEventHook(api, eventHook) { + switch (api) { + case 'openai': + eventSource.removeListener(event_types.CHAT_COMPLETION_SETTINGS_READY, eventHook); + break; + default: + eventSource.removeListener(event_types.GENERATE_AFTER_DATA, eventHook); + break; + } } } @@ -6871,6 +6940,11 @@ export async function saveSettings(type) { return; } + if (TempResponseLength.isCustomized()) { + console.warn('Response length is currently being overridden. Restoring previous value before saving.'); + TempResponseLength.restore(null); + } + //console.log('Entering settings with name1 = '+name1); return jQuery.ajax({ type: 'POST',