diff --git a/gradio_app.py b/gradio_app.py index 564af54..80d96ec 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -90,7 +90,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name, if align_model is None: raise gr.Error("Align model required for whisperx backend") transcribe_model = WhisperxModel(whisper_model_name, align_model) - + voicecraft_name = f"{voicecraft_model_name}.pth" ckpt_fn = f"./pretrained_models/{voicecraft_name}" encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th" @@ -132,7 +132,7 @@ def transcribe(seed, audio_path): if transcribe_model is None: raise gr.Error("Transcription model not loaded") seed_everything(seed) - + segments = transcribe_model.transcribe(audio_path) state = get_transcribe_state(segments) @@ -234,7 +234,7 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, if mode != "Edit": from inference_tts_scale import inference_one_sample - if smart_transcript: + if smart_transcript: target_transcript = "" for word in transcribe_state["words_info"]: if word["end"] < prompt_end_time: @@ -281,7 +281,7 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, morphed_span = (max(edit_start_time - left_margin, 1 / codec_sr), min(edit_end_time + right_margin, audio_dur)) mask_interval = [[round(morphed_span[0]*codec_sr), round(morphed_span[1]*codec_sr)]] mask_interval = torch.LongTensor(mask_interval) - + _, gen_audio = inference_one_sample(voicecraft_model["model"], voicecraft_model["ckpt"]["config"], voicecraft_model["ckpt"]["phn2num"], @@ -300,12 +300,12 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, output_audio = get_output_audio(previous_audio_tensors, codec_audio_sr) sentence_audio = get_output_audio(audio_tensors, codec_audio_sr) return output_audio, inference_transcript, sentence_audio, previous_audio_tensors - - + + def update_input_audio(audio_path): if audio_path is None: return 0, 0, 0 - + info = torchaudio.info(audio_path) max_time = round(info.num_frames / info.sample_rate, 2) return [ @@ -314,7 +314,7 @@ def update_input_audio(audio_path): gr.Slider(maximum=max_time, value=max_time), ] - + def change_mode(mode): tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor return [ @@ -416,7 +416,7 @@ demo_words_info = [ def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word): if transcript not in all_demo_texts: return transcript, edit_from_word, edit_to_word - + replace_half = edit_word_mode == "Replace half" change_edit_from_word = edit_from_word == demo_words[2] or edit_from_word == demo_words[3] change_edit_to_word = edit_to_word == demo_words[11] or edit_to_word == demo_words[12] @@ -456,7 +456,7 @@ with gr.Blocks() as app: transcribe_btn = gr.Button(value="Transcribe") align_btn = gr.Button(value="Align") - + with gr.Column(scale=3): with gr.Group(): transcript = gr.Textbox(label="Text", lines=7, value=demo_text["TTS"]["smart"]) @@ -471,7 +471,7 @@ with gr.Blocks() as app: info="Split text into parts and run TTS for each part.", visible=False) edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace half", info="What to do with first and last word", visible=False) - + with gr.Group() as tts_mode_controls: prompt_to_word = gr.Dropdown(label="Last word in prompt", choices=demo_words, value=demo_words[10], interactive=True) prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.001, value=3.016) @@ -517,11 +517,11 @@ with gr.Blocks() as app: codec_sr = gr.Number(label="codec_sr", value=50, info='encodec specific, Do not change') silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]", info="encodec specific, do not change") - + audio_tensors = gr.State() transcribe_state = gr.State(value={"words_info": demo_words_info}) - + mode.change(fn=update_demo, inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word], outputs=[transcript, edit_from_word, edit_to_word]) @@ -531,11 +531,11 @@ with gr.Blocks() as app: smart_transcript.change(fn=update_demo, inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word], outputs=[transcript, edit_from_word, edit_to_word]) - + load_models_btn.click(fn=load_models, inputs=[whisper_backend_choice, whisper_model_choice, align_model_choice, voicecraft_model_choice], outputs=[models_selector]) - + input_audio.upload(fn=update_input_audio, inputs=[input_audio], outputs=[prompt_end_time, edit_start_time, edit_end_time]) @@ -564,7 +564,7 @@ with gr.Blocks() as app: split_text, sentence_selector, audio_tensors ], outputs=[output_audio, inference_transcript, sentence_selector, audio_tensors]) - + sentence_selector.change(fn=load_sentence, inputs=[sentence_selector, codec_audio_sr, audio_tensors], outputs=[sentence_audio]) @@ -580,7 +580,7 @@ with gr.Blocks() as app: split_text, sentence_selector, audio_tensors ], outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors]) - + prompt_to_word.change(fn=update_bound_word, inputs=[gr.State(False), prompt_to_word, gr.State("Replace all")], outputs=[prompt_end_time])