bugfixes, seed support, better ui

This commit is contained in:
Stepan Zuev
2024-04-03 20:24:34 +03:00
parent f9fed26b15
commit 1a219cf6da

View File

@@ -8,11 +8,24 @@ from data.tokenizer import (
from models import voicecraft from models import voicecraft
import os import os
import io import io
import numpy as np
import random
whisper_model, voicecraft_model = None, None whisper_model, voicecraft_model = None, None
def seed_everything(seed):
if seed != -1:
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def load_models(whisper_model_choice, voicecraft_model_choice): def load_models(whisper_model_choice, voicecraft_model_choice):
global whisper_model, voicecraft_model global whisper_model, voicecraft_model
@@ -50,12 +63,13 @@ def load_models(whisper_model_choice, voicecraft_model_choice):
"audio_tokenizer": AudioTokenizer(signature=encodec_fn) "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
} }
return gr.Audio(interactive=True) return gr.Accordion()
def transcribe(audio_path): def transcribe(seed, audio_path):
if whisper_model is None: if whisper_model is None:
raise gr.Error("Whisper model not loaded") raise gr.Error("Whisper model not loaded")
seed_everything(seed)
number_tokens = [ number_tokens = [
i i
@@ -73,6 +87,7 @@ def transcribe(audio_path):
return [ return [
transcript, transcript_with_start_time, transcript_with_end_time, transcript, transcript_with_start_time, transcript_with_end_time,
gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # prompt_to_word
gr.Dropdown(value=choices[0], choices=choices, interactive=True), # edit_from_word gr.Dropdown(value=choices[0], choices=choices, interactive=True), # edit_from_word
gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # edit_to_word gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # edit_to_word
words words
@@ -87,7 +102,7 @@ def get_output_audio(audio_tensors, codec_audio_sr):
return buffer.read() return buffer.read()
def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature, def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
stop_repetition, sample_batch_size, kvcache, silence_tokens, stop_repetition, sample_batch_size, kvcache, silence_tokens,
audio_path, word_info, transcript, smart_transcript, audio_path, word_info, transcript, smart_transcript,
mode, prompt_end_time, edit_start_time, edit_end_time, mode, prompt_end_time, edit_start_time, edit_end_time,
@@ -97,6 +112,7 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe
if smart_transcript and (word_info is None): if smart_transcript and (word_info is None):
raise gr.Error("Can't use smart transcript: whisper transcript not found") raise gr.Error("Can't use smart transcript: whisper transcript not found")
seed_everything(seed)
if mode == "Long TTS": if mode == "Long TTS":
if split_text == "Newline": if split_text == "Newline":
sentences = transcript.split('\n') sentences = transcript.split('\n')
@@ -192,6 +208,9 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe
def update_input_audio(audio_path): def update_input_audio(audio_path):
if audio_path is None:
return 0, 0, 0
info = torchaudio.info(audio_path) info = torchaudio.info(audio_path)
max_time = round(info.num_frames / info.sample_rate, 2) max_time = round(info.num_frames / info.sample_rate, 2)
return [ return [
@@ -202,12 +221,12 @@ def update_input_audio(audio_path):
def change_mode(mode): def change_mode(mode):
tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor
return [ return [
gr.Slider(visible=mode != "Edit"), gr.Group(visible=mode != "Edit"),
gr.Radio(visible=mode == "Long TTS"), gr.Group(visible=mode == "Edit"),
gr.Radio(visible=mode == "Edit"), gr.Radio(visible=mode == "Edit"),
gr.Row(visible=mode == "Edit"), gr.Radio(visible=mode == "Long TTS"),
gr.Accordion(visible=mode == "Edit"),
gr.Group(visible=mode == "Long TTS"), gr.Group(visible=mode == "Long TTS"),
] ]
@@ -253,6 +272,8 @@ If disabled, you should write the target transcript yourself:</br>
- In Edit mode write full prompt</br> - In Edit mode write full prompt</br>
""" """
demo_original_transcript = " But when I had approached so near to them, the common object, which the sense deceives, lost not by distance any of its marks."
demo_text = { demo_text = {
"TTS": { "TTS": {
"smart": "I cannot believe that the same model can also do text to speech synthesis as well!", "smart": "I cannot believe that the same model can also do text to speech synthesis as well!",
@@ -281,17 +302,35 @@ demo_words = [
'5.74 by 6.08', '6.08 distance 6.36', '6.36 any 6.92', '6.92 of 7.12', '7.12 its 7.26', '7.26 marks. 7.54' '5.74 by 6.08', '6.08 distance 6.36', '6.36 any 6.92', '6.92 of 7.12', '7.12 its 7.26', '7.26 marks. 7.54'
] ]
demo_word_info = [
{'word': ' But', 'start': 0.0, 'end': 0.12}, {'word': ' when', 'start': 0.12, 'end': 0.26},
{'word': ' I', 'start': 0.26, 'end': 0.44}, {'word': ' had', 'start': 0.44, 'end': 0.6},
{'word': ' approached', 'start': 0.6, 'end': 0.94}, {'word': ' so', 'start': 0.94, 'end': 1.42},
{'word': ' near', 'start': 1.42, 'end': 1.78}, {'word': ' to', 'start': 1.78, 'end': 2.02},
{'word': ' them,', 'start': 2.02, 'end': 2.24}, {'word': ' the', 'start': 2.52, 'end': 2.58},
{'word': ' common', 'start': 2.58, 'end': 2.9}, {'word': ' object,', 'start': 2.9, 'end': 3.3},
{'word': ' which', 'start': 3.72, 'end': 3.78}, {'word': ' the', 'start': 3.78, 'end': 3.98},
{'word': ' sense', 'start': 3.98, 'end': 4.18}, {'word': ' deceives,', 'start': 4.18, 'end': 4.88},
{'word': ' lost', 'start': 5.06, 'end': 5.26}, {'word': ' not', 'start': 5.26, 'end': 5.74},
{'word': ' by', 'start': 5.74, 'end': 6.08}, {'word': ' distance', 'start': 6.08, 'end': 6.36},
{'word': ' any', 'start': 6.36, 'end': 6.92}, {'word': ' of', 'start': 6.92, 'end': 7.12},
{'word': ' its', 'start': 7.12, 'end': 7.26}, {'word': ' marks.', 'start': 7.26, 'end': 7.54}
]
def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time):
def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word):
if transcript not in all_demo_texts: if transcript not in all_demo_texts:
return transcript, edit_from_word, edit_to_word, prompt_end_time return transcript, edit_from_word, edit_to_word
replace_half = edit_word_mode == "Replace half" replace_half = edit_word_mode == "Replace half"
change_edit_from_word = edit_from_word == demo_words[2] or edit_from_word == demo_words[3]
change_edit_to_word = edit_to_word == demo_words[11] or edit_to_word == demo_words[12]
demo_edit_from_word_value = demo_words[2] if replace_half else demo_words[3]
demo_edit_to_word_value = demo_words[12] if replace_half else demo_words[11]
return [ return [
demo_text[mode]["smart" if smart_transcript else "regular"], demo_text[mode]["smart" if smart_transcript else "regular"],
"0.26 I 0.44" if replace_half else "0.44 had 0.6", demo_edit_from_word_value if change_edit_from_word else edit_from_word,
"3.72 which 3.78" if replace_half else "2.9 object, 3.3", demo_edit_to_word_value if change_edit_to_word else edit_to_word,
3.01,
] ]
@@ -300,7 +339,7 @@ with gr.Blocks() as app:
with gr.Column(scale=2): with gr.Column(scale=2):
load_models_btn = gr.Button(value="Load models") load_models_btn = gr.Button(value="Load models")
with gr.Column(scale=5): with gr.Column(scale=5):
with gr.Accordion("Select models", open=False): with gr.Accordion("Select models", open=False) as models_selector:
with gr.Row(): with gr.Row():
voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"]) voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"])
whisper_model_choice = gr.Radio(label="Whisper model", value="base.en", whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
@@ -308,9 +347,9 @@ with gr.Blocks() as app:
with gr.Row(): with gr.Row():
with gr.Column(scale=2): with gr.Column(scale=2):
input_audio = gr.Audio(value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=False) input_audio = gr.Audio(value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath")
with gr.Group(): with gr.Group():
original_transcript = gr.Textbox(label="Original transcript", lines=5, interactive=False, original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript, interactive=False,
info="Use whisper model to get the transcript. Fix it if necessary.") info="Use whisper model to get the transcript. Fix it if necessary.")
with gr.Accordion("Word start time", open=False): with gr.Accordion("Word start time", open=False):
transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word") transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word")
@@ -325,20 +364,26 @@ with gr.Blocks() as app:
with gr.Row(): with gr.Row():
smart_transcript = gr.Checkbox(label="Smart transcript", value=True) smart_transcript = gr.Checkbox(label="Smart transcript", value=True)
with gr.Accordion(label="?", open=False): with gr.Accordion(label="?", open=False):
info = gr.HTML(value=smart_transcript_info) info = gr.Markdown(value=smart_transcript_info)
mode = gr.Radio(label="Mode", choices=["TTS", "Edit", "Long TTS"], value="TTS")
with gr.Row():
mode = gr.Radio(label="Mode", choices=["TTS", "Edit", "Long TTS"], value="TTS")
split_text = gr.Radio(label="Split text", choices=["Newline", "Sentence"], value="Newline",
info="Split text into parts and run TTS for each part.", visible=False)
edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace half",
info="What to do with first and last word", visible=False)
with gr.Group() as tts_mode_controls:
prompt_to_word = gr.Dropdown(label="Last word in prompt", choices=demo_words, value=demo_words[10], interactive=True)
prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.01, value=3.01) prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.01, value=3.01)
split_text = gr.Radio(label="Split text", choices=["Newline", "Sentence"], value="Newline", visible=False,
info="Split text into parts and run TTS for each part.") with gr.Group(visible=False) as edit_mode_controls:
edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace half", visible=False, with gr.Row():
info="What to do with first and last word") edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, value=demo_words[2], interactive=True)
with gr.Row(visible=False) as segment_control: edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, value=demo_words[12], interactive=True)
edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, interactive=True) with gr.Row():
edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, interactive=True) edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.93, step=0.01, value=0.35)
with gr.Accordion("Precise segment control", open=False, visible=False) as precise_segment_control: edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.93, step=0.01, value=3.75)
edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=60, step=0.01, value=0)
edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=60, step=0.01, value=60)
run_btn = gr.Button(value="Run") run_btn = gr.Button(value="Run")
@@ -347,7 +392,7 @@ with gr.Blocks() as app:
with gr.Accordion("Inference transcript", open=False): with gr.Accordion("Inference transcript", open=False):
inference_transcript = gr.Textbox(label="Inference transcript", lines=5, interactive=False, inference_transcript = gr.Textbox(label="Inference transcript", lines=5, interactive=False,
info="Inference was performed on this transcript.") info="Inference was performed on this transcript.")
with gr.Group(visible=False) as long_tts_controls: with gr.Group(visible=False) as long_tts_sentence_editor:
sentence_selector = gr.Dropdown(label="Sentence", value=None, sentence_selector = gr.Dropdown(label="Sentence", value=None,
info="Select sentence you want to regenerate") info="Select sentence you want to regenerate")
sentence_audio = gr.Audio(label="Sentence Audio", scale=2) sentence_audio = gr.Audio(label="Sentence Audio", scale=2)
@@ -355,6 +400,7 @@ with gr.Blocks() as app:
with gr.Row(): with gr.Row():
with gr.Accordion("VoiceCraft config", open=False): with gr.Accordion("VoiceCraft config", open=False):
seed = gr.Number(label="seed", value=-1, precision=0)
left_margin = gr.Number(label="left_margin", value=0.08) left_margin = gr.Number(label="left_margin", value=0.08)
right_margin = gr.Number(label="right_margin", value=0.08) right_margin = gr.Number(label="right_margin", value=0.08)
codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000) codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000)
@@ -372,37 +418,38 @@ with gr.Blocks() as app:
audio_tensors = gr.State() audio_tensors = gr.State()
word_info = gr.State() word_info = gr.State(value=demo_word_info)
mode.change(fn=update_demo, mode.change(fn=update_demo,
inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time], inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time]) outputs=[transcript, edit_from_word, edit_to_word])
edit_word_mode.change(fn=update_demo, edit_word_mode.change(fn=update_demo,
inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time], inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time]) outputs=[transcript, edit_from_word, edit_to_word])
smart_transcript.change(fn=update_demo, smart_transcript.change(fn=update_demo,
inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time], inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time]) outputs=[transcript, edit_from_word, edit_to_word])
load_models_btn.click(fn=load_models, load_models_btn.click(fn=load_models,
inputs=[whisper_model_choice, voicecraft_model_choice], inputs=[whisper_model_choice, voicecraft_model_choice],
outputs=[input_audio]) outputs=[models_selector])
input_audio.change(fn=update_input_audio, input_audio.upload(fn=update_input_audio,
inputs=[input_audio], inputs=[input_audio],
outputs=[prompt_end_time, edit_start_time, edit_end_time]) outputs=[prompt_end_time, edit_start_time, edit_end_time])
transcribe_btn.click(fn=transcribe, transcribe_btn.click(fn=transcribe,
inputs=[input_audio], inputs=[seed, input_audio],
outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time, edit_from_word, edit_to_word, word_info]) outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time,
prompt_to_word, edit_from_word, edit_to_word, word_info])
mode.change(fn=change_mode, mode.change(fn=change_mode,
inputs=[mode], inputs=[mode],
outputs=[prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls]) outputs=[tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor])
run_btn.click(fn=run, run_btn.click(fn=run,
inputs=[ inputs=[
left_margin, right_margin, seed, left_margin, right_margin,
codec_audio_sr, codec_sr, codec_audio_sr, codec_sr,
top_k, top_p, temperature, top_k, top_p, temperature,
stop_repetition, sample_batch_size, stop_repetition, sample_batch_size,
@@ -418,7 +465,7 @@ with gr.Blocks() as app:
outputs=[sentence_audio]) outputs=[sentence_audio])
rerun_btn.click(fn=run, rerun_btn.click(fn=run,
inputs=[ inputs=[
left_margin, right_margin, seed, left_margin, right_margin,
codec_audio_sr, codec_sr, codec_audio_sr, codec_sr,
top_k, top_p, temperature, top_k, top_p, temperature,
stop_repetition, sample_batch_size, stop_repetition, sample_batch_size,
@@ -429,6 +476,9 @@ with gr.Blocks() as app:
], ],
outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors]) outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors])
prompt_to_word.change(fn=update_bound_word,
inputs=[gr.State(False), prompt_to_word, gr.State("Replace all")],
outputs=[prompt_end_time])
edit_from_word.change(fn=update_bound_word, edit_from_word.change(fn=update_bound_word,
inputs=[gr.State(True), edit_from_word, edit_word_mode], inputs=[gr.State(True), edit_from_word, edit_word_mode],
outputs=[edit_start_time]) outputs=[edit_start_time])