Refactoring StreamProcessor -> ReasoningHandler

This commit is contained in:
Wolfsblvt
2025-02-09 01:26:01 +01:00
parent 8b4414b799
commit d8eeab0c00
2 changed files with 169 additions and 80 deletions

View File

@ -269,7 +269,7 @@ import { initSettingsSearch } from './scripts/setting-search.js';
import { initBulkEdit } from './scripts/bulk-edit.js';
import { deriveTemplatesFromChatTemplate } from './scripts/chat-templates.js';
import { getContext } from './scripts/st-context.js';
import { extractReasoningFromData, initReasoning, isHiddenReasoningModel, PromptReasoning, updateReasoningUI } from './scripts/reasoning.js';
import { extractReasoningFromData, initReasoning, PromptReasoning, ReasoningHandler, updateReasoningUI } from './scripts/reasoning.js';
// API OBJECT FOR EXTERNAL WIRING
globalThis.SillyTavern = {
@ -3128,10 +3128,6 @@ class StreamingProcessor {
this.messageTimerDom = null;
/** @type {HTMLElement} */
this.messageTokenCounterDom = null;
/** @type {HTMLElement} */
this.messageReasoningDom = null;
/** @type {HTMLElement} */
this.messageReasoningHeaderDom = null;
/** @type {HTMLTextAreaElement} */
this.sendTextarea = document.querySelector('#send_textarea');
this.type = type;
@ -3147,20 +3143,8 @@ class StreamingProcessor {
/** @type {import('./scripts/logprobs.js').TokenLogprobs[]} */
this.messageLogprobs = [];
this.toolCalls = [];
this.reasoning = '';
/** @type {Date} */
this.reasoningStartTime = null;
/** @type {Date} */
this.reasoningEndTime = null;
this.isHiddenReasoning = isHiddenReasoningModel();
}
/** @type {() => number} Reasoning duration in milliseconds */
#reasoningDuration() {
if (this.reasoningStartTime && this.reasoningEndTime) {
return (this.reasoningEndTime.getTime() - this.reasoningStartTime.getTime());
}
return null;
// Initialize reasoning in its own handler
this.reasoningHandler = new ReasoningHandler(type, timeStarted);
}
#checkDomElements(messageId) {
@ -3169,11 +3153,8 @@ class StreamingProcessor {
this.messageTextDom = this.messageDom?.querySelector('.mes_text');
this.messageTimerDom = this.messageDom?.querySelector('.mes_timer');
this.messageTokenCounterDom = this.messageDom?.querySelector('.tokenCounterDisplay');
this.messageReasoningDom = this.messageDom?.querySelector('.mes_reasoning');
this.messageReasoningHeaderDom = this.messageDom?.querySelector('.mes_reasoning_header_title');
}
this.messageDom.classList.toggle('reasoning_hidden', this.isHiddenReasoning);
this.reasoningHandler.checkDomElements(messageId);
}
#updateMessageBlockVisibility() {
@ -3184,19 +3165,11 @@ class StreamingProcessor {
}
showMessageButtons(messageId) {
if (messageId == -1) {
return;
}
showStopButton();
$(`#chat .mes[mesid="${messageId}"] .mes_buttons`).css({ 'display': 'none' });
}
hideMessageButtons(messageId) {
if (messageId == -1) {
return;
}
hideStopButton();
$(`#chat .mes[mesid="${messageId}"] .mes_buttons`).css({ 'display': 'flex' });
}
@ -3207,14 +3180,12 @@ class StreamingProcessor {
if (this.type == 'impersonate') {
this.sendTextarea.value = '';
this.sendTextarea.dispatchEvent(new Event('input', { bubbles: true }));
}
else {
} else {
await saveReply(this.type, text, true, '', [], '');
messageId = chat.length - 1;
this.#checkDomElements(messageId);
this.showMessageButtons(messageId);
}
hideSwipeButtons();
scrollChatToBottom();
return messageId;
@ -3232,11 +3203,9 @@ class StreamingProcessor {
let processedText = cleanUpMessage(text, isImpersonate, isContinue, !isFinal, this.stoppingStrings);
// Predict unbalanced asterisks / quotes during streaming
const charsToBalance = ['*', '"', '```'];
for (const char of charsToBalance) {
if (!isFinal && isOdd(countOccurrences(processedText, char))) {
// Add character at the end to balance it
const separator = char.length > 1 ? '\n' : '';
processedText = processedText.trimEnd() + separator + char;
}
@ -3245,48 +3214,24 @@ class StreamingProcessor {
if (isImpersonate) {
this.sendTextarea.value = processedText;
this.sendTextarea.dispatchEvent(new Event('input', { bubbles: true }));
}
else {
} else {
const mesChanged = chat[messageId]['mes'] !== processedText;
this.#checkDomElements(messageId);
this.#updateMessageBlockVisibility();
const currentTime = new Date();
chat[messageId]['mes'] = processedText;
chat[messageId]['gen_started'] = this.timeStarted;
chat[messageId]['gen_finished'] = currentTime;
if (!chat[messageId]['extra']) {
chat[messageId]['extra'] = {};
}
if (this.reasoning || this.isHiddenReasoning) {
const reasoning = power_user.trim_spaces ? this.reasoning.trim() : this.reasoning;
const reasoningChanged = chat[messageId]['extra']['reasoning'] !== reasoning;
chat[messageId]['extra']['reasoning'] = reasoning;
// Update reasoning
await this.reasoningHandler.process(messageId, mesChanged, currentTime);
if ((this.isHiddenReasoning || reasoningChanged) && this.reasoningStartTime === null) {
this.reasoningStartTime = this.timeStarted;
}
if ((this.isHiddenReasoning || !reasoningChanged) && mesChanged && this.reasoningStartTime !== null && this.reasoningEndTime === null) {
this.reasoningEndTime = currentTime;
await eventSource.emit(event_types.STREAM_REASONING_DONE, this.reasoning, this.#reasoningDuration);
}
await this.#updateReasoningTime(messageId);
if (this.messageReasoningDom instanceof HTMLElement) {
const formattedReasoning = messageFormatting(this.reasoning, '', false, false, messageId, {}, true);
this.messageReasoningDom.innerHTML = formattedReasoning;
}
if (this.messageDom instanceof HTMLElement) {
this.messageDom.classList.add('reasoning');
}
}
// Don't waste time calculating token count for streaming
const tokenCountText = (this.reasoning || '') + processedText;
// Token count update.
const tokenCountText = this.reasoningHandler.reasoning + processedText;
const currentTokenCount = isFinal && power_user.message_token_count_enabled ? getTokenCount(tokenCountText, 0) : 0;
if (currentTokenCount) {
chat[messageId]['extra']['token_count'] = currentTokenCount;
if (this.messageTokenCounterDom instanceof HTMLElement) {
@ -3312,7 +3257,7 @@ class StreamingProcessor {
this.messageTextDom.innerHTML = formattedText;
}
const timePassed = formatGenerationTimer(this.timeStarted, currentTime, currentTokenCount, this.#reasoningDuration());
const timePassed = formatGenerationTimer(this.timeStarted, currentTime, currentTokenCount, this.reasoningHandler.getDuration());
if (this.messageTimerDom instanceof HTMLElement) {
this.messageTimerDom.textContent = timePassed.timerValue;
this.messageTimerDom.title = timePassed.timerTitle;
@ -3326,23 +3271,12 @@ class StreamingProcessor {
}
}
async #updateReasoningTime(messageId, { forceEnd = false } = {}) {
const duration = this.#reasoningDuration();
chat[messageId]['extra']['reasoning_duration'] = duration;
updateReasoningUI(this.messageDom, this.reasoning, duration, { forceEnd: forceEnd });
}
async onFinishStreaming(messageId, text) {
this.hideMessageButtons(this.messageId);
await this.onProgressStreaming(messageId, text, true);
addCopyToCodeBlocks($(`#chat .mes[mesid="${messageId}"]`));
// Ensure reasoning finish time is recorded if not already
if (this.reasoningStartTime !== null && this.reasoningEndTime === null) {
this.reasoningEndTime = new Date();
await eventSource.emit(event_types.STREAM_REASONING_DONE, this.reasoning, this.#reasoningDuration);
await this.#updateReasoningTime(messageId, { forceEnd: true });
}
await this.reasoningHandler.finish(messageId);
if (Array.isArray(this.swipes) && this.swipes.length > 0) {
const message = chat[messageId];
@ -3443,7 +3377,8 @@ class StreamingProcessor {
if (logprobs) {
this.messageLogprobs.push(...(Array.isArray(logprobs) ? logprobs : [logprobs]));
}
this.reasoning = getRegexedString(state?.reasoning ?? '', regex_placement.REASONING);
// Get the updated reasoning string into the handler
this.reasoningHandler.updateReasoning(state?.reasoning ?? '');
await eventSource.emit(event_types.STREAM_TOKEN_RECEIVED, text);
await sw.tick(async () => await this.onProgressStreaming(this.messageId, this.continueMessage + text));
}