From 3d3f32ba7e3bfb3ef2d7514c6081911f06d5fb2b Mon Sep 17 00:00:00 2001 From: Stepan Zuev Date: Thu, 4 Apr 2024 22:22:27 +0300 Subject: [PATCH] whisperx added --- README.md | 2 +- gradio_app.py | 231 +++++++++++++++++++++++++++++----------- gradio_requirements.txt | 4 +- 3 files changed, 170 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 2a5be3a..5e78082 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ It is ready to use on [default url](http://127.0.0.1:7860). 6. (optionally) Rerun part-by-part in Long TTS mode ### Some features -Smart transcript: write only what you want to generate, but don't work if you edit original transcript +Smart transcript: write only what you want to generate TTS mode: Zero-shot TTS diff --git a/gradio_app.py b/gradio_app.py index 707defa..4321a11 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -10,9 +10,16 @@ import os import io import numpy as np import random +import uuid -whisper_model, voicecraft_model = None, None +TMP_PATH = "./demo/temp" +device = "cuda" if torch.cuda.is_available() else "cpu" +whisper_model, align_model, voicecraft_model = None, None, None + + +def get_random_string(): + return "".join(str(uuid.uuid4()).split("-")) def seed_everything(seed): @@ -26,22 +33,63 @@ def seed_everything(seed): torch.backends.cudnn.deterministic = True -def load_models(whisper_model_choice, voicecraft_model_choice): - global whisper_model, voicecraft_model +class WhisperxAlignModel: + def __init__(self): + from whisperx import load_align_model + self.model, self.metadata = load_align_model(language_code="en", device=device) + + def align(self, segments, audio_path): + from whisperx import align, load_audio + audio = load_audio(audio_path) + return align(segments, self.model, self.metadata, audio, device, return_char_alignments=False)["segments"] + + +class WhisperModel: + def __init__(self, model_name): + from whisper import load_model + self.model = load_model(model_name, device) - if whisper_model_choice is not None: - import whisper from whisper.tokenizer import get_tokenizer - whisper_model = { - "model": whisper.load_model(whisper_model_choice), - "tokenizer": get_tokenizer(multilingual=False) - } + tokenizer = get_tokenizer(multilingual=False) + self.supress_tokens = [-1] + [ + i + for i in range(tokenizer.eot) + if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" ")) + ] + + def transcribe(self, audio_path): + return self.model.transcribe(audio_path, suppress_tokens=self.supress_tokens, word_timestamps=True)["segments"] + + +class WhisperxModel: + def __init__(self, model_name, align_model: WhisperxAlignModel): + from whisperx import load_model + self.model = load_model(model_name, device, asr_options={"suppress_numerals": True}) + self.align_model = align_model + + def transcribe(self, audio_path): + segments = self.model.transcribe(audio_path, batch_size=8)["segments"] + return self.align_model.align(segments, audio_path) + + +def load_models(whisper_backend_name, whisper_model_name, alignment_model_name, voicecraft_model_name): + global transcribe_model, align_model, voicecraft_model os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = "0" - device = "cuda" if torch.cuda.is_available() else "cpu" + + if alignment_model_name is not None: + align_model = WhisperxAlignModel() + + if whisper_model_name is not None: + if whisper_backend_name == "whisper": + transcribe_model = WhisperModel(whisper_model_name) + else: + if align_model is None: + raise gr.Error("Align model 
required for whisperx backend") + transcribe_model = WhisperxModel(whisper_model_name, align_model) - voicecraft_name = f"{voicecraft_model_choice}.pth" + voicecraft_name = f"{voicecraft_model_name}.pth" ckpt_fn = f"./pretrained_models/{voicecraft_name}" encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th" if not os.path.exists(ckpt_fn): @@ -66,31 +114,78 @@ def load_models(whisper_model_choice, voicecraft_model_choice): return gr.Accordion() +def get_transcribe_state(segments): + words_info = [word_info for segment in segments for word_info in segment["words"]] + return { + "segments": segments, + "transcript": " ".join([segment["text"] for segment in segments]), + "words_info": words_info, + "transcript_with_start_time": " ".join([f"{word['start']} {word['word']}" for word in words_info]), + "transcript_with_end_time": " ".join([f"{word['word']} {word['end']}" for word in words_info]), + "word_bounds": [f"{word['start']} {word['word']} {word['end']}" for word in words_info] + } + + def transcribe(seed, audio_path): - if whisper_model is None: - raise gr.Error("Whisper model not loaded") + if transcribe_model is None: + raise gr.Error("Transcription model not loaded") seed_everything(seed) - number_tokens = [ - i - for i in range(whisper_model["tokenizer"].eot) - if all(c in "0123456789" for c in whisper_model["tokenizer"].decode([i]).removeprefix(" ")) - ] - result = whisper_model["model"].transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True) - words = [word_info for segment in result["segments"] for word_info in segment["words"]] - - transcript = result["text"] - transcript_with_start_time = " ".join([f"{word['start']} {word['word']}" for word in words]) - transcript_with_end_time = " ".join([f"{word['word']} {word['end']}" for word in words]) - - choices = [f"{word['start']} {word['word']} {word['end']}" for word in words] + segments = transcribe_model.transcribe(audio_path) + state = get_transcribe_state(segments) return [ - transcript, transcript_with_start_time, transcript_with_end_time, - gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # prompt_to_word - gr.Dropdown(value=choices[0], choices=choices, interactive=True), # edit_from_word - gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # edit_to_word - words + state["transcript"], state["transcript_with_start_time"], state["transcript_with_end_time"], + gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # prompt_to_word + gr.Dropdown(value=state["word_bounds"][0], choices=state["word_bounds"], interactive=True), # edit_from_word + gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # edit_to_word + state + ] + + +def align_segments(transcript, audio_path): + from aeneas.executetask import ExecuteTask + from aeneas.task import Task + import json + config_string = 'task_language=eng|os_task_file_format=json|is_text_type=plain' + + tmp_transcript_path = os.path.join(TMP_PATH, f"{get_random_string()}.txt") + tmp_sync_map_path = os.path.join(TMP_PATH, f"{get_random_string()}.json") + with open(tmp_transcript_path, "w") as f: + f.write(transcript) + + task = Task(config_string=config_string) + task.audio_file_path_absolute = os.path.abspath(audio_path) + task.text_file_path_absolute = os.path.abspath(tmp_transcript_path) + task.sync_map_file_path_absolute = os.path.abspath(tmp_sync_map_path) + ExecuteTask(task).execute() + task.output_sync_map_file() + + with open(tmp_sync_map_path, 
"r") as f: + return json.load(f) + + +def align(seed, transcript, audio_path): + if align_model is None: + raise gr.Error("Align model not loaded") + seed_everything(seed) + + fragments = align_segments(transcript, audio_path) + segments = [{ + "start": float(fragment["begin"]), + "end": float(fragment["end"]), + "text": " ".join(fragment["lines"]) + } for fragment in fragments["fragments"]] + segments = align_model.align(segments, audio_path) + state = get_transcribe_state(segments) + print(state) + + return [ + state["transcript_with_start_time"], state["transcript_with_end_time"], + gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # prompt_to_word + gr.Dropdown(value=state["word_bounds"][0], choices=state["word_bounds"], interactive=True), # edit_from_word + gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # edit_to_word + state ] @@ -104,12 +199,12 @@ def get_output_audio(audio_tensors, codec_audio_sr): def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature, stop_repetition, sample_batch_size, kvcache, silence_tokens, - audio_path, word_info, transcript, smart_transcript, + audio_path, transcribe_state, transcript, smart_transcript, mode, prompt_end_time, edit_start_time, edit_end_time, split_text, selected_sentence, previous_audio_tensors): if voicecraft_model is None: raise gr.Error("VoiceCraft model not loaded") - if smart_transcript and (word_info is None): + if smart_transcript and (transcribe_state is None): raise gr.Error("Can't use smart transcript: whisper transcript not found") seed_everything(seed) @@ -126,7 +221,6 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, else: sentences = [transcript.replace("\n", " ")] - device = "cuda" if torch.cuda.is_available() else "cpu" info = torchaudio.info(audio_path) audio_dur = info.num_frames / info.sample_rate @@ -141,7 +235,7 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, if smart_transcript: target_transcript = "" - for word in word_info: + for word in transcribe_state["words_info"]: if word["end"] < prompt_end_time: target_transcript += word["word"] elif (word["start"] + word["end"]) / 2 < prompt_end_time: @@ -169,13 +263,13 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, if smart_transcript: target_transcript = "" - for word in word_info: + for word in transcribe_state["words_info"]: if word["start"] < edit_start_time: target_transcript += word["word"] else: break target_transcript += f" {sentence}" - for word in word_info: + for word in transcribe_state["words_info"]: if word["end"] > edit_end_time: target_transcript += word["word"] else: @@ -296,25 +390,25 @@ demo_text = { all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()} demo_words = [ - '0.0 But 0.12', '0.12 when 0.26', '0.26 I 0.44', '0.44 had 0.6', '0.6 approached 0.94', '0.94 so 1.42', - '1.42 near 1.78', '1.78 to 2.02', '2.02 them, 2.24', '2.52 the 2.58', '2.58 common 2.9', '2.9 object, 3.3', - '3.72 which 3.78', '3.78 the 3.98', '3.98 sense 4.18', '4.18 deceives, 4.88', '5.06 lost 5.26', '5.26 not 5.74', - '5.74 by 6.08', '6.08 distance 6.36', '6.36 any 6.92', '6.92 of 7.12', '7.12 its 7.26', '7.26 marks. 
7.54' + '0.029 But 0.149', '0.189 when 0.33', '0.43 I 0.49', '0.53 had 0.65', '0.711 approached 1.152', '1.352 so 1.593', + '1.693 near 1.933', '1.994 to 2.074', '2.134 them, 2.354', '2.535 the 2.655', '2.695 common 3.016', '3.196 object, 3.577', + '3.717 which 3.898', '3.958 the 4.058', '4.098 sense 4.359', '4.419 deceives, 4.92', '5.101 lost 5.481', '5.682 not 5.963', + '6.043 by 6.183', '6.223 distance 6.644', '6.905 any 7.065', '7.125 of 7.185', '7.245 its 7.346', '7.406 marks. 7.727' ] -demo_word_info = [ - {'word': ' But', 'start': 0.0, 'end': 0.12}, {'word': ' when', 'start': 0.12, 'end': 0.26}, - {'word': ' I', 'start': 0.26, 'end': 0.44}, {'word': ' had', 'start': 0.44, 'end': 0.6}, - {'word': ' approached', 'start': 0.6, 'end': 0.94}, {'word': ' so', 'start': 0.94, 'end': 1.42}, - {'word': ' near', 'start': 1.42, 'end': 1.78}, {'word': ' to', 'start': 1.78, 'end': 2.02}, - {'word': ' them,', 'start': 2.02, 'end': 2.24}, {'word': ' the', 'start': 2.52, 'end': 2.58}, - {'word': ' common', 'start': 2.58, 'end': 2.9}, {'word': ' object,', 'start': 2.9, 'end': 3.3}, - {'word': ' which', 'start': 3.72, 'end': 3.78}, {'word': ' the', 'start': 3.78, 'end': 3.98}, - {'word': ' sense', 'start': 3.98, 'end': 4.18}, {'word': ' deceives,', 'start': 4.18, 'end': 4.88}, - {'word': ' lost', 'start': 5.06, 'end': 5.26}, {'word': ' not', 'start': 5.26, 'end': 5.74}, - {'word': ' by', 'start': 5.74, 'end': 6.08}, {'word': ' distance', 'start': 6.08, 'end': 6.36}, - {'word': ' any', 'start': 6.36, 'end': 6.92}, {'word': ' of', 'start': 6.92, 'end': 7.12}, - {'word': ' its', 'start': 7.12, 'end': 7.26}, {'word': ' marks.', 'start': 7.26, 'end': 7.54} +demo_words_info = [ + {'word': 'But', 'start': 0.029, 'end': 0.149, 'score': 0.834}, {'word': 'when', 'start': 0.189, 'end': 0.33, 'score': 0.879}, + {'word': 'I', 'start': 0.43, 'end': 0.49, 'score': 0.984}, {'word': 'had', 'start': 0.53, 'end': 0.65, 'score': 0.998}, + {'word': 'approached', 'start': 0.711, 'end': 1.152, 'score': 0.822}, {'word': 'so', 'start': 1.352, 'end': 1.593, 'score': 0.822}, + {'word': 'near', 'start': 1.693, 'end': 1.933, 'score': 0.752}, {'word': 'to', 'start': 1.994, 'end': 2.074, 'score': 0.924}, + {'word': 'them,', 'start': 2.134, 'end': 2.354, 'score': 0.914}, {'word': 'the', 'start': 2.535, 'end': 2.655, 'score': 0.818}, + {'word': 'common', 'start': 2.695, 'end': 3.016, 'score': 0.971}, {'word': 'object,', 'start': 3.196, 'end': 3.577, 'score': 0.823}, + {'word': 'which', 'start': 3.717, 'end': 3.898, 'score': 0.701}, {'word': 'the', 'start': 3.958, 'end': 4.058, 'score': 0.798}, + {'word': 'sense', 'start': 4.098, 'end': 4.359, 'score': 0.797}, {'word': 'deceives,', 'start': 4.419, 'end': 4.92, 'score': 0.802}, + {'word': 'lost', 'start': 5.101, 'end': 5.481, 'score': 0.71}, {'word': 'not', 'start': 5.682, 'end': 5.963, 'score': 0.781}, + {'word': 'by', 'start': 6.043, 'end': 6.183, 'score': 0.834}, {'word': 'distance', 'start': 6.223, 'end': 6.644, 'score': 0.899}, + {'word': 'any', 'start': 6.905, 'end': 7.065, 'score': 0.893}, {'word': 'of', 'start': 7.125, 'end': 7.185, 'score': 0.772}, + {'word': 'its', 'start': 7.245, 'end': 7.346, 'score': 0.778}, {'word': 'marks.', 'start': 7.406, 'end': 7.727, 'score': 0.955} ] @@ -342,21 +436,24 @@ with gr.Blocks() as app: with gr.Accordion("Select models", open=False) as models_selector: with gr.Row(): voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"]) + whisper_backend_choice = gr.Radio(label="Whisper backend", 
value="whisperX", choices=["whisper", "whisperX"]) whisper_model_choice = gr.Radio(label="Whisper model", value="base.en", - choices=[None, "tiny.en", "base.en", "small.en", "medium.en", "large"]) + choices=[None, "base.en", "small.en", "medium.en", "large"]) + align_model_choice = gr.Radio(label="Forced alignment model", value="whisperX", choices=[None, "whisperX"]) with gr.Row(): with gr.Column(scale=2): input_audio = gr.Audio(value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath") with gr.Group(): - original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript, interactive=False, - info="Use whisper model to get the transcript. Fix it if necessary.") + original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript, + info="Use whisper model to get the transcript. Fix and align it if necessary.") with gr.Accordion("Word start time", open=False): transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word") with gr.Accordion("Word end time", open=False): transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word") transcribe_btn = gr.Button(value="Transcribe") + align_btn = gr.Button(value="Align") with gr.Column(scale=3): with gr.Group(): @@ -375,15 +472,15 @@ with gr.Blocks() as app: with gr.Group() as tts_mode_controls: prompt_to_word = gr.Dropdown(label="Last word in prompt", choices=demo_words, value=demo_words[10], interactive=True) - prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.01, value=3.01) + prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.001, value=3.016) with gr.Group(visible=False) as edit_mode_controls: with gr.Row(): edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, value=demo_words[2], interactive=True) edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, value=demo_words[12], interactive=True) with gr.Row(): - edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.93, step=0.01, value=0.35) - edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.93, step=0.01, value=3.75) + edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.93, step=0.001, value=0.46) + edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.93, step=0.001, value=3.808) run_btn = gr.Button(value="Run") @@ -418,7 +515,7 @@ with gr.Blocks() as app: audio_tensors = gr.State() - word_info = gr.State(value=demo_word_info) + transcribe_state = gr.State(value={"words_info": demo_words_info}) mode.change(fn=update_demo, @@ -432,7 +529,7 @@ with gr.Blocks() as app: outputs=[transcript, edit_from_word, edit_to_word]) load_models_btn.click(fn=load_models, - inputs=[whisper_model_choice, voicecraft_model_choice], + inputs=[whisper_backend_choice, whisper_model_choice, align_model_choice, voicecraft_model_choice], outputs=[models_selector]) input_audio.upload(fn=update_input_audio, @@ -441,7 +538,11 @@ with gr.Blocks() as app: transcribe_btn.click(fn=transcribe, inputs=[seed, input_audio], outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time, - prompt_to_word, edit_from_word, edit_to_word, word_info]) + prompt_to_word, edit_from_word, edit_to_word, transcribe_state]) + align_btn.click(fn=align, + inputs=[seed, original_transcript, input_audio], + 
outputs=[transcript_with_start_time, transcript_with_end_time, + prompt_to_word, edit_from_word, edit_to_word, transcribe_state]) mode.change(fn=change_mode, inputs=[mode], @@ -454,7 +555,7 @@ with gr.Blocks() as app: top_k, top_p, temperature, stop_repetition, sample_batch_size, kvcache, silence_tokens, - input_audio, word_info, transcript, smart_transcript, + input_audio, transcribe_state, transcript, smart_transcript, mode, prompt_end_time, edit_start_time, edit_end_time, split_text, sentence_selector, audio_tensors ], @@ -470,7 +571,7 @@ with gr.Blocks() as app: top_k, top_p, temperature, stop_repetition, sample_batch_size, kvcache, silence_tokens, - input_audio, word_info, transcript, smart_transcript, + input_audio, transcribe_state, transcript, smart_transcript, gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time, split_text, sentence_selector, audio_tensors ], diff --git a/gradio_requirements.txt b/gradio_requirements.txt index 949acc7..967b3d7 100644 --- a/gradio_requirements.txt +++ b/gradio_requirements.txt @@ -1,3 +1,5 @@ gradio==3.50.2 nltk>=3.8.1 -openai-whisper>=20231117 \ No newline at end of file +openai-whisper>=20231117 +aeneas>=1.7.3.0 +whisperx>=3.1.1 \ No newline at end of file
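
The new whisperX backend in this patch chains three whisperx calls: load_model for batched ASR with numeral suppression, load_align_model for the English alignment model, and align for word-level timestamps. Below is a minimal standalone sketch of that flow, mirroring WhisperxModel and WhisperxAlignModel outside the Gradio callbacks; the audio path, model size, and device handling are placeholder assumptions, not part of the patch.

    # Standalone sketch of the whisperX transcribe-then-align flow used by
    # WhisperxModel / WhisperxAlignModel. "audio.wav" and the model size are
    # placeholder assumptions; error handling and CPU compute types are omitted.
    import torch
    from whisperx import load_model, load_align_model, load_audio, align

    device = "cuda" if torch.cuda.is_available() else "cpu"

    asr = load_model("base.en", device, asr_options={"suppress_numerals": True})
    segments = asr.transcribe("audio.wav", batch_size=8)["segments"]

    align_model, metadata = load_align_model(language_code="en", device=device)
    audio = load_audio("audio.wav")
    aligned = align(segments, align_model, metadata, audio, device,
                    return_char_alignments=False)["segments"]

    for segment in aligned:
        for word in segment["words"]:
            print(word["start"], word["word"], word["end"])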
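
The new Align button goes through aeneas first: the (possibly hand-corrected) transcript is force-aligned against the audio at fragment level, and the resulting fragments are then refined to word level by WhisperxAlignModel.align. A rough sketch of just the aeneas step follows; the fixed temp-file names and tmp_dir default are assumptions for this sketch, whereas the in-app align_segments() uses random file names under ./demo/temp.

    # Rough sketch of the aeneas step behind align_segments(): force-align a
    # corrected transcript against the audio and turn the sync map fragments
    # into segments that WhisperxAlignModel.align() can refine to word level.
    import json
    import os
    from aeneas.executetask import ExecuteTask
    from aeneas.task import Task


    def rough_align(transcript, audio_path, tmp_dir="./demo/temp"):
        os.makedirs(tmp_dir, exist_ok=True)
        txt_path = os.path.join(tmp_dir, "transcript.txt")   # assumed names
        map_path = os.path.join(tmp_dir, "sync_map.json")
        with open(txt_path, "w") as f:
            f.write(transcript)

        task = Task(config_string="task_language=eng|os_task_file_format=json|is_text_type=plain")
        task.audio_file_path_absolute = os.path.abspath(audio_path)
        task.text_file_path_absolute = os.path.abspath(txt_path)
        task.sync_map_file_path_absolute = os.path.abspath(map_path)
        ExecuteTask(task).execute()
        task.output_sync_map_file()

        with open(map_path) as f:
            fragments = json.load(f)["fragments"]

        # Coarse, fragment-level segments; word timestamps come from the
        # subsequent whisperX alignment pass.
        return [{
            "start": float(fragment["begin"]),
            "end": float(fragment["end"]),
            "text": " ".join(fragment["lines"]),
        } for fragment in fragments]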
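
Both Transcribe and Align now return the same transcribe_state dictionary built by get_transcribe_state(); its fields drive the start/end-time textboxes, the word-bound dropdowns, and the smart-transcript logic in run(). The sketch below reconstructs its shape from the first two entries of demo_words_info; the segment text is illustrative.

    # Shape of the transcribe_state dict, rebuilt with get_transcribe_state()
    # logic on two words of the demo data; the segment text is illustrative.
    segments = [{
        "text": "But when",
        "words": [
            {"word": "But", "start": 0.029, "end": 0.149, "score": 0.834},
            {"word": "when", "start": 0.189, "end": 0.33, "score": 0.879},
        ],
    }]

    words_info = [w for seg in segments for w in seg["words"]]
    state = {
        "segments": segments,
        "transcript": " ".join(seg["text"] for seg in segments),  # "But when"
        "words_info": words_info,
        "transcript_with_start_time": " ".join(f"{w['start']} {w['word']}" for w in words_info),  # "0.029 But 0.189 when"
        "transcript_with_end_time": " ".join(f"{w['word']} {w['end']}" for w in words_info),      # "But 0.149 when 0.33"
        "word_bounds": [f"{w['start']} {w['word']} {w['end']}" for w in words_info],              # dropdown entries
    }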