Kokoro: chunk generation, add pre-process func

#3412
This commit is contained in:
Cohee
2025-03-12 21:35:09 +02:00
parent e65b72ea41
commit 1b817cd897
2 changed files with 52 additions and 22 deletions

View File

@ -1,5 +1,5 @@
import { debounce_timeout } from '../../constants.js'; import { debounce_timeout } from '../../constants.js';
import { debounceAsync } from '../../utils.js'; import { debounceAsync, splitRecursive } from '../../utils.js';
import { getPreviewString, saveTtsProviderSettings } from './index.js'; import { getPreviewString, saveTtsProviderSettings } from './index.js';
export class KokoroTtsProvider { export class KokoroTtsProvider {
@ -52,6 +52,17 @@ export class KokoroTtsProvider {
this.initTtsDebounced = debounceAsync(this.initializeWorker.bind(this), debounce_timeout.relaxed); this.initTtsDebounced = debounceAsync(this.initializeWorker.bind(this), debounce_timeout.relaxed);
} }
/**
* Perform any text processing before passing to TTS engine.
* @param {string} text Input text
* @returns {string} Processed text
*/
processText(text) {
// TILDE!
text = text.replace(/~/g, '.');
return text;
}
async loadSettings(settings) { async loadSettings(settings) {
if (settings.modelId !== undefined) this.settings.modelId = settings.modelId; if (settings.modelId !== undefined) this.settings.modelId = settings.modelId;
if (settings.dtype !== undefined) this.settings.dtype = settings.dtype; if (settings.dtype !== undefined) this.settings.dtype = settings.dtype;
@ -258,13 +269,17 @@ export class KokoroTtsProvider {
const voice = this.getVoice(voiceId); const voice = this.getVoice(voiceId);
const previewText = getPreviewString(voice.lang); const previewText = getPreviewString(voice.lang);
const response = await this.generateTts(previewText, voiceId); for await (const response of this.generateTts(previewText, voiceId)) {
const audio = await response.blob(); const audio = await response.blob();
const url = URL.createObjectURL(audio); const url = URL.createObjectURL(audio);
const audioElement = new Audio(); await new Promise(resolve => {
audioElement.src = url; const audioElement = new Audio();
audioElement.play(); audioElement.src = url;
audioElement.onended = () => URL.revokeObjectURL(url); audioElement.play();
audioElement.onended = () => resolve();
});
URL.revokeObjectURL(url);
}
} }
getVoiceDisplayName(voiceId) { getVoiceDisplayName(voiceId) {
@ -282,7 +297,13 @@ export class KokoroTtsProvider {
}; };
} }
async generateTts(text, voiceId) { /**
* Generate TTS audio for the given text using the specified voice.
* @param {string} text Text to generate
* @param {string} voiceId Voice ID
* @returns {AsyncGenerator<Response>} Audio response generator
*/
async* generateTts(text, voiceId) {
if (!this.ready || !this.worker) { if (!this.ready || !this.worker) {
console.log('TTS not ready, initializing...'); console.log('TTS not ready, initializing...');
await this.initializeWorker(); await this.initializeWorker();
@ -299,21 +320,26 @@ export class KokoroTtsProvider {
const voice = this.getVoice(voiceId); const voice = this.getVoice(voiceId);
const requestId = this.nextRequestId++; const requestId = this.nextRequestId++;
return new Promise((resolve, reject) => { const chunkSize = 400;
// Store the promise callbacks const chunks = splitRecursive(text, chunkSize, ['\n\n', '\n', '.', '?', '!', ',', ' ', '']);
this.pendingRequests.set(requestId, { resolve, reject });
// Send the request to the worker for (const chunk of chunks) {
this.worker.postMessage({ yield await new Promise((resolve, reject) => {
action: 'generateTts', // Store the promise callbacks
data: { this.pendingRequests.set(requestId, { resolve, reject });
text,
voice: voice.voice_id, // Send the request to the worker
speakingRate: this.settings.speakingRate || 1.0, this.worker.postMessage({
requestId, action: 'generateTts',
}, data: {
text: chunk,
voice: voice.voice_id,
speakingRate: this.settings.speakingRate || 1.0,
requestId,
},
});
}); });
}); }
} }
dispose() { dispose() {

View File

@ -1015,6 +1015,10 @@ export function splitRecursive(input, length, delimiters = ['\n\n', '\n', ' ', '
return result; return result;
} }
/**
 * Builds the chunk-matching pattern for cutting text into sentence-sized pieces.
 * NOTE(review): only the pattern is constructed — it is never applied to `input`
 * and there is no return statement, so callers receive `undefined`. This looks
 * unfinished; confirm the intended behavior before relying on it.
 * @param {string} input Text to split (currently unused; presumably a string)
 * @param {number} length Maximum chunk length, interpolated into the pattern's quantifiers
 * @returns {void}
 */
export function splitSentences(input, length) {
    const halfLength = Math.floor(length / 2);
    // Alternatives, in order:
    //   1. halfLength..length characters followed by one of . ! ? ,
    //   2. the whole remaining text when it is at most `length` characters
    //   3. any leading run of up to `length` characters
    const chunkPattern = new RegExp(`^[\\s\\S]{${halfLength},${length}}[.!?,]{1}|^[\\s\\S]{1,${length}}$|^[\\s\\S]{1,${length}}`);
}
/** /**
* Checks if a string is a valid data URL. * Checks if a string is a valid data URL.
* @param {string} str The string to check. * @param {string} str The string to check.