Kokoro: chunk generation, add pre-process func

#3412
This commit is contained in:
Cohee
2025-03-12 21:35:09 +02:00
parent e65b72ea41
commit 1b817cd897
2 changed files with 52 additions and 22 deletions

View File

@ -1,5 +1,5 @@
import { debounce_timeout } from '../../constants.js'; import { debounce_timeout } from '../../constants.js';
import { debounceAsync } from '../../utils.js'; import { debounceAsync, splitRecursive } from '../../utils.js';
import { getPreviewString, saveTtsProviderSettings } from './index.js'; import { getPreviewString, saveTtsProviderSettings } from './index.js';
export class KokoroTtsProvider { export class KokoroTtsProvider {
@ -52,6 +52,17 @@ export class KokoroTtsProvider {
this.initTtsDebounced = debounceAsync(this.initializeWorker.bind(this), debounce_timeout.relaxed); this.initTtsDebounced = debounceAsync(this.initializeWorker.bind(this), debounce_timeout.relaxed);
} }
/**
* Perform any text processing before passing to TTS engine.
* @param {string} text Input text
* @returns {string} Processed text
*/
processText(text) {
// TILDE!
text = text.replace(/~/g, '.');
return text;
}
async loadSettings(settings) { async loadSettings(settings) {
if (settings.modelId !== undefined) this.settings.modelId = settings.modelId; if (settings.modelId !== undefined) this.settings.modelId = settings.modelId;
if (settings.dtype !== undefined) this.settings.dtype = settings.dtype; if (settings.dtype !== undefined) this.settings.dtype = settings.dtype;
@ -258,13 +269,17 @@ export class KokoroTtsProvider {
const voice = this.getVoice(voiceId); const voice = this.getVoice(voiceId);
const previewText = getPreviewString(voice.lang); const previewText = getPreviewString(voice.lang);
const response = await this.generateTts(previewText, voiceId); for await (const response of this.generateTts(previewText, voiceId)) {
const audio = await response.blob(); const audio = await response.blob();
const url = URL.createObjectURL(audio); const url = URL.createObjectURL(audio);
await new Promise(resolve => {
const audioElement = new Audio(); const audioElement = new Audio();
audioElement.src = url; audioElement.src = url;
audioElement.play(); audioElement.play();
audioElement.onended = () => URL.revokeObjectURL(url); audioElement.onended = () => resolve();
});
URL.revokeObjectURL(url);
}
} }
getVoiceDisplayName(voiceId) { getVoiceDisplayName(voiceId) {
@ -282,7 +297,13 @@ export class KokoroTtsProvider {
}; };
} }
async generateTts(text, voiceId) { /**
* Generate TTS audio for the given text using the specified voice.
* @param {string} text Text to generate
* @param {string} voiceId Voice ID
* @returns {AsyncGenerator<Response>} Audio response generator
*/
async* generateTts(text, voiceId) {
if (!this.ready || !this.worker) { if (!this.ready || !this.worker) {
console.log('TTS not ready, initializing...'); console.log('TTS not ready, initializing...');
await this.initializeWorker(); await this.initializeWorker();
@ -299,7 +320,11 @@ export class KokoroTtsProvider {
const voice = this.getVoice(voiceId); const voice = this.getVoice(voiceId);
const requestId = this.nextRequestId++; const requestId = this.nextRequestId++;
return new Promise((resolve, reject) => { const chunkSize = 400;
const chunks = splitRecursive(text, chunkSize, ['\n\n', '\n', '.', '?', '!', ',', ' ', '']);
for (const chunk of chunks) {
yield await new Promise((resolve, reject) => {
// Store the promise callbacks // Store the promise callbacks
this.pendingRequests.set(requestId, { resolve, reject }); this.pendingRequests.set(requestId, { resolve, reject });
@ -307,7 +332,7 @@ export class KokoroTtsProvider {
this.worker.postMessage({ this.worker.postMessage({
action: 'generateTts', action: 'generateTts',
data: { data: {
text, text: chunk,
voice: voice.voice_id, voice: voice.voice_id,
speakingRate: this.settings.speakingRate || 1.0, speakingRate: this.settings.speakingRate || 1.0,
requestId, requestId,
@ -315,6 +340,7 @@ export class KokoroTtsProvider {
}); });
}); });
} }
}
dispose() { dispose() {
// Clean up the worker when the provider is disposed // Clean up the worker when the provider is disposed

View File

@ -1015,6 +1015,10 @@ export function splitRecursive(input, length, delimiters = ['\n\n', '\n', ' ', '
return result; return result;
} }
export function splitSentences(input, length) {
var pattRegex = new RegExp(`^[\\s\\S]{${Math.floor(length / 2)},${length}}[.!?,]{1}|^[\\s\\S]{1,${length}}$|^[\\s\\S]{1,${length}}`);
}
/** /**
* Checks if a string is a valid data URL. * Checks if a string is a valid data URL.
* @param {string} str The string to check. * @param {string} str The string to check.