From b10a245b44638343cbfbcc939a583eeb34814c49 Mon Sep 17 00:00:00 2001 From: jason-on-salt-a40 Date: Sat, 20 Apr 2024 19:50:35 -0700 Subject: [PATCH] new model --- gradio_app.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/gradio_app.py b/gradio_app.py index 496ce50..b9887a3 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -77,8 +77,14 @@ class WhisperxModel: def load_models(whisper_backend_name, whisper_model_name, alignment_model_name, voicecraft_model_name): global transcribe_model, align_model, voicecraft_model - if voicecraft_model_name == "giga330M_TTSEnhanced": + if voicecraft_model_name == "330M": + voicecraft_model_name = "giga330M" + elif voicecraft_model_name == "830M": + voicecraft_model_name = "giga830M" + elif voicecraft_model_name == "330M_TTSEnhanced": voicecraft_model_name = "gigaHalfLibri330M_TTSEnhanced_max16s" + elif voicecraft_model_name == "830M_TTSEnhanced": + voicecraft_model_name = "830M_TTSEnhanced" if alignment_model_name is not None: align_model = WhisperxAlignModel() @@ -434,19 +440,19 @@ def get_app(): with gr.Column(scale=5): with gr.Accordion("Select models", open=False) as models_selector: with gr.Row(): - voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", - choices=["giga330M", "giga830M", "giga330M_TTSEnhanced"]) - whisper_backend_choice = gr.Radio(label="Whisper backend", value="whisperX", choices=["whisper", "whisperX"]) + voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="830M_TTSEnhanced", + choices=["330M", "830M", "330M_TTSEnhanced", "830M_TTSEnhanced"]) + whisper_backend_choice = gr.Radio(label="Whisper backend", value="whisperX", choices=["whisperX", "whisper"]) whisper_model_choice = gr.Radio(label="Whisper model", value="base.en", choices=[None, "base.en", "small.en", "medium.en", "large"]) - align_model_choice = gr.Radio(label="Forced alignment model", value="whisperX", choices=[None, "whisperX"]) + align_model_choice = gr.Radio(label="Forced alignment model", value="whisperX", choices=["whisperX", None]) with gr.Row(): with gr.Column(scale=2): input_audio = gr.Audio(value=f"{DEMO_PATH}/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True) with gr.Group(): original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript, - info="Use whisper model to get the transcript. Fix and align it if necessary.") + info="Use whisperx model to get the transcript. Fix and align it if necessary.") with gr.Accordion("Word start time", open=False): transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word") with gr.Accordion("Word end time", open=False): @@ -499,7 +505,7 @@ def get_app(): with gr.Accordion("Generation Parameters - change these if you are unhappy with the generation", open=False): stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3, 4], value=3, info="if there are long silence in the generated audio, reduce the stop_repetition to 2 or 1. -1 = disabled") - sample_batch_size = gr.Number(label="speech rate", value=4, precision=0, + sample_batch_size = gr.Number(label="speech rate", value=3, precision=0, info="The higher the number, the faster the output will be. " "Under the hood, the model will generate this many samples and choose the shortest one. " "For giga330M_TTSEnhanced, 1 or 2 should be fine since the model is trained to do TTS.")