new model

This commit is contained in:
jason-on-salt-a40 2024-04-20 19:50:35 -07:00
parent 13e52470c3
commit b10a245b44
1 changed files with 13 additions and 7 deletions

View File

@ -77,8 +77,14 @@ class WhisperxModel:
def load_models(whisper_backend_name, whisper_model_name, alignment_model_name, voicecraft_model_name):
global transcribe_model, align_model, voicecraft_model
if voicecraft_model_name == "giga330M_TTSEnhanced":
if voicecraft_model_name == "330M":
voicecraft_model_name = "giga330M"
elif voicecraft_model_name == "830M":
voicecraft_model_name = "giga830M"
elif voicecraft_model_name == "330M_TTSEnhanced":
voicecraft_model_name = "gigaHalfLibri330M_TTSEnhanced_max16s"
elif voicecraft_model_name == "830M_TTSEnhanced":
voicecraft_model_name = "830M_TTSEnhanced"
if alignment_model_name is not None:
align_model = WhisperxAlignModel()
@ -434,19 +440,19 @@ def get_app():
with gr.Column(scale=5):
with gr.Accordion("Select models", open=False) as models_selector:
with gr.Row():
voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M",
choices=["giga330M", "giga830M", "giga330M_TTSEnhanced"])
whisper_backend_choice = gr.Radio(label="Whisper backend", value="whisperX", choices=["whisper", "whisperX"])
voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="830M_TTSEnhanced",
choices=["330M", "830M", "330M_TTSEnhanced", "830M_TTSEnhanced"])
whisper_backend_choice = gr.Radio(label="Whisper backend", value="whisperX", choices=["whisperX", "whisper"])
whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
choices=[None, "base.en", "small.en", "medium.en", "large"])
align_model_choice = gr.Radio(label="Forced alignment model", value="whisperX", choices=[None, "whisperX"])
align_model_choice = gr.Radio(label="Forced alignment model", value="whisperX", choices=["whisperX", None])
with gr.Row():
with gr.Column(scale=2):
input_audio = gr.Audio(value=f"{DEMO_PATH}/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True)
with gr.Group():
original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript,
info="Use whisper model to get the transcript. Fix and align it if necessary.")
info="Use whisperx model to get the transcript. Fix and align it if necessary.")
with gr.Accordion("Word start time", open=False):
transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word")
with gr.Accordion("Word end time", open=False):
@ -499,7 +505,7 @@ def get_app():
with gr.Accordion("Generation Parameters - change these if you are unhappy with the generation", open=False):
stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3, 4], value=3,
info="if there are long silence in the generated audio, reduce the stop_repetition to 2 or 1. -1 = disabled")
sample_batch_size = gr.Number(label="speech rate", value=4, precision=0,
sample_batch_size = gr.Number(label="speech rate", value=3, precision=0,
info="The higher the number, the faster the output will be. "
"Under the hood, the model will generate this many samples and choose the shortest one. "
"For giga330M_TTSEnhanced, 1 or 2 should be fine since the model is trained to do TTS.")