VoiceCraft/gradio_app.py

437 lines
23 KiB
Python
Raw Normal View History

2024-04-02 16:58:15 +02:00
import gradio as gr
import torch
import torchaudio
from data.tokenizer import (
AudioTokenizer,
TextTokenizer,
)
from models import voicecraft
import whisper
from whisper.tokenizer import get_tokenizer
import os
import io
# Model handles start out empty; the "Load models" handler (load_models) assigns both.
whisper_model = voicecraft_model = None
# Inference device: the GPU when torch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def _ensure_pretrained_file(filename, directory="./pretrained_models"):
    """Return the local path of *filename*, downloading it from the VoiceCraft HF repo if absent.

    The download goes through torch's downloader, which raises on HTTP/network errors and only
    moves the file into place once it is complete. As a result, a failed or interrupted download
    never leaves a truncated file behind for the existence check to accept on the next run.

    Raises:
        ValueError: if *filename* contains a path component. The name is built from a
            client-supplied UI value, so it must not escape *directory*.
    """
    if os.path.basename(filename) != filename:
        raise ValueError(f"Invalid model file name: {filename!r}")
    path = os.path.join(directory, filename)
    if not os.path.exists(path):
        from torch.hub import download_url_to_file

        os.makedirs(directory, exist_ok=True)
        download_url_to_file(f"https://huggingface.co/pyp1/VoiceCraft/resolve/main/{filename}?download=true", path)
    return path


def load_models(input_audio, transcribe_btn, run_btn, rerun_btn):
    """Build the click handler for the "Load models" button.

    Args:
        input_audio, transcribe_btn, run_btn, rerun_btn: components that are enabled once
            the models are loaded.

    Returns:
        ``impl(whisper_model_choice, voicecraft_model_choice)``. It loads the Whisper model and
        the VoiceCraft checkpoint into the module-level ``whisper_model`` / ``voicecraft_model``
        globals, downloading missing files into ./pretrained_models first. It then returns
        updates that make the four components interactive.
    """
    def impl(whisper_model_choice, voicecraft_model_choice):
        global whisper_model, voicecraft_model
        whisper_model = whisper.load_model(whisper_model_choice)

        # NOTE(review): torch is already queried for CUDA at import time (module-level `device`),
        # so these variables may have no effect here. They are kept as in the original.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"

        ckpt_fn = _ensure_pretrained_file(f"{voicecraft_model_choice}.pth")
        encodec_fn = _ensure_pretrained_file("encodec_4cb2048_giga.th")

        # torch.load unpickles the checkpoint, so only load files from the trusted VoiceCraft repo.
        ckpt = torch.load(ckpt_fn, map_location="cpu")
        model = voicecraft.VoiceCraft(ckpt["config"])
        model.load_state_dict(ckpt["model"])
        model.to(device)
        model.eval()
        # Publish the new model only once everything has loaded. A failure part-way through
        # leaves the previously loaded model (if any) intact instead of a half-filled dict.
        voicecraft_model = {
            "ckpt": ckpt,
            "model": model,
            "text_tokenizer": TextTokenizer(backend="espeak"),
            "audio_tokenizer": AudioTokenizer(signature=encodec_fn),
        }
        return [
            input_audio.update(interactive=True),
            transcribe_btn.update(interactive=True),
            run_btn.update(interactive=True),
            rerun_btn.update(interactive=True),
        ]
    return impl
def transcribe(audio_path):
    """Transcribe *audio_path* with the loaded Whisper model, including word-level timestamps.

    Returns a list matching the Transcribe button's outputs:
    - the plain transcript;
    - the transcript with each word's start time before it;
    - the transcript with each word's end time after it;
    - refreshed "First word to edit" / "Last word to edit" dropdowns whose choices are
      "<start> <word> <end>" strings;
    - the raw Whisper word dicts.

    When Whisper returns no words, the dropdowns are returned empty with no value selected.
    """
    tokenizer = get_tokenizer(multilingual=False)
    # Collect token ids whose text, ignoring one leading space, is made only of digits.
    # Suppressing them stops Whisper from emitting numerals (presumably so numbers come out as
    # words). Note: all() is True for an empty string, so a token decoding to a lone space is
    # suppressed too.
    number_tokens = [
        i
        for i in range(tokenizer.eot)
        if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
    ]
    # -1 keeps Whisper's default suppression set in addition to the digit tokens.
    result = whisper_model.transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True)
    words = [word_info for segment in result["segments"] for word_info in segment["words"]]

    transcript = result["text"]
    transcript_with_start_time = " ".join(f"{word['start']} {word['word']}" for word in words)
    transcript_with_end_time = " ".join(f"{word['word']} {word['end']}" for word in words)
    choices = [f"{word['start']} {word['word']} {word['end']}" for word in words]
    # Silent or speech-free audio yields no words; choices[0] would raise IndexError then.
    first_choice = choices[0] if choices else None
    last_choice = choices[-1] if choices else None
    edit_from_word = gr.Dropdown(label="First word to edit", value=first_choice, choices=choices, interactive=True)
    edit_to_word = gr.Dropdown(label="Last word to edit", value=last_choice, choices=choices, interactive=True)
    return [
        transcript, transcript_with_start_time, transcript_with_end_time,
        edit_from_word, edit_to_word, words,
    ]
def get_output_audio(audio_tensors, codec_audio_sr):
    """Concatenate audio chunks along dimension 1 and encode the result as an in-memory WAV.

    Args:
        audio_tensors: sequence of waveform tensors, joined with ``torch.cat`` on dim 1.
        codec_audio_sr: sample rate written to the WAV header (converted with ``int``).

    Returns:
        The complete WAV file as ``bytes``.
    """
    waveform = torch.cat(audio_tensors, dim=1)
    with io.BytesIO() as wav_file:
        torchaudio.save(wav_file, waveform, int(codec_audio_sr), format="wav")
        return wav_file.getvalue()
def _split_sentences(transcript, mode, split_text, selected_sentence):
    """Return ``(sentences, selected_idx)``: the text parts to synthesize, plus the Rerun index.

    - Long TTS splits on newlines or with NLTK's sentence tokenizer. Blank parts are dropped,
      so a trailing or doubled newline does not start a generation for empty text.
    - Rerun parses the "<idx>: <text>" dropdown entry produced by a previous run.
    - Any other mode synthesizes the whole text as a single part.

    ``selected_idx`` is ``None`` outside Rerun mode.
    """
    if mode == "Long TTS":
        if split_text == "Newline":
            parts = transcript.split("\n")
        else:
            from nltk.tokenize import sent_tokenize
            parts = sent_tokenize(transcript.replace("\n", " "))
        return [part for part in parts if part.strip()], None
    if mode == "Rerun":
        if not selected_sentence:
            raise gr.Error("Run Long TTS first, then select the sentence to regenerate.")
        idx_text, _, sentence = selected_sentence.partition(":")
        return [sentence], int(idx_text)
    return [transcript.replace("\n", " ")], None


def _tts_target_transcript(word_info, sentence, prompt_end_time):
    """Prefix *sentence* with the words of the prompt audio that end before *prompt_end_time*.

    If the word straddling the cut has its midpoint inside the prompt, that word is kept and
    the cut moves to its end.

    Returns:
        ``(target_transcript, prompt_end_time)``, where the second value is possibly adjusted.
    """
    prefix = ""
    for word in word_info:
        if word["end"] < prompt_end_time:
            prefix += word["word"]
        elif (word["start"] + word["end"]) / 2 < prompt_end_time:
            # Include part of the word if it's big, but adjust prompt_end_time.
            prefix += word["word"]
            prompt_end_time = word["end"]
            break
        else:
            break
    return f"{prefix} {sentence}", prompt_end_time


def _edit_target_transcript(word_info, sentence, edit_start_time, edit_end_time):
    """Build the edit transcript from three parts, in order.

    - Words starting before the edit span.
    - *sentence*.
    - Every word ending after the edit span.
    """
    head = ""
    for word in word_info:
        if word["start"] >= edit_start_time:
            break
        head += word["word"]
    tail = "".join(word["word"] for word in word_info if word["end"] > edit_end_time)
    return f"{head} {sentence}{tail}"


def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
        stop_repetition, sample_batch_size, kvcache, silence_tokens,
        audio_path, word_info, transcript, smart_transcript,
        mode, prompt_end_time, edit_start_time, edit_end_time,
        split_text, selected_sentence, previous_audio_tensors):
    """Run VoiceCraft TTS or speech editing for the "Run" and "Rerun" buttons.

    ``mode`` is one of:
    - "TTS", "Long TTS" or "Edit", taken from the Mode radio;
    - "Rerun", which regenerates the sentence chosen in the sentence dropdown.

    Each text part is synthesized separately, and the generated chunks are concatenated.

    Returns:
        For TTS / Long TTS / Edit: ``(wav_bytes, inference_transcript, sentence_dropdown,
        audio_tensors)``.
        For Rerun: ``(wav_bytes, inference_transcript, sentence_wav_bytes,
        previous_audio_tensors)``. The regenerated chunk replaces the selected entry of
        ``previous_audio_tensors`` in place.

    Raises:
        gr.Error: if there is no text to synthesize, if Rerun has no selected sentence, or if
            Smart transcript is enabled but the input audio has not been transcribed yet.
    """
    sentences, selected_sentence_idx = _split_sentences(transcript, mode, split_text, selected_sentence)
    if not sentences:
        raise gr.Error("There is no text to synthesize.")
    if smart_transcript and word_info is None:
        raise gr.Error("Smart transcript needs word timings: click Transcribe first or disable Smart transcript.")

    # Both inference modules export a function named inference_one_sample.
    if mode == "Edit":
        from inference_speech_editing_scale import inference_one_sample as edit_one_sample
    else:
        from inference_tts_scale import inference_one_sample as tts_one_sample

    info = torchaudio.info(audio_path)
    audio_dur = info.num_frames / info.sample_rate
    model = voicecraft_model["model"]
    config = voicecraft_model["ckpt"]["config"]
    phn2num = voicecraft_model["ckpt"]["phn2num"]
    text_tokenizer = voicecraft_model["text_tokenizer"]
    audio_tokenizer = voicecraft_model["audio_tokenizer"]

    audio_tensors = []
    inference_transcript = ""
    for sentence in sentences:
        # A fresh dict per call, as before, in case the inference code mutates it.
        # NOTE(review): silence_tokens is the raw Textbox string (e.g. "[1388,1898,131]");
        # it is passed through unparsed, so confirm how inference_one_sample parses it.
        decode_config = {"top_k": top_k, "top_p": top_p, "temperature": temperature, "stop_repetition": stop_repetition,
                         "kvcache": kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
                         "silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size}
        if mode != "Edit":
            if smart_transcript:
                target_transcript, prompt_end_time = _tts_target_transcript(word_info, sentence, prompt_end_time)
            else:
                target_transcript = sentence
            inference_transcript += target_transcript + "\n"
            prompt_end_frame = int(min(audio_dur, prompt_end_time) * info.sample_rate)
            _, gen_audio = tts_one_sample(model, config, phn2num, text_tokenizer, audio_tokenizer,
                                          audio_path, target_transcript, device, decode_config,
                                          prompt_end_frame)
        else:
            if smart_transcript:
                target_transcript = _edit_target_transcript(word_info, sentence, edit_start_time, edit_end_time)
            else:
                target_transcript = sentence
            inference_transcript += target_transcript + "\n"
            # Widen the edit span by the margins, clamped to [one codec frame, audio end],
            # then convert it from seconds to codec frames.
            morphed_span = (max(edit_start_time - left_margin, 1 / codec_sr),
                            min(edit_end_time + right_margin, audio_dur))
            mask_interval = torch.LongTensor([[round(morphed_span[0] * codec_sr), round(morphed_span[1] * codec_sr)]])
            _, gen_audio = edit_one_sample(model, config, phn2num, text_tokenizer, audio_tokenizer,
                                           audio_path, target_transcript, mask_interval, device, decode_config)
        audio_tensors.append(gen_audio[0].cpu())

    if mode != "Rerun":
        output_audio = get_output_audio(audio_tensors, codec_audio_sr)
        choices = [f"{idx}: {text}" for idx, text in enumerate(sentences)]
        component = gr.Dropdown(label="Sentence", choices=choices, value=choices[0],
                                info="Select sentence you want to regenerate")
        return output_audio, inference_transcript, component, audio_tensors

    previous_audio_tensors[selected_sentence_idx] = audio_tensors[0]
    output_audio = get_output_audio(previous_audio_tensors, codec_audio_sr)
    sentence_audio = get_output_audio(audio_tensors, codec_audio_sr)
    return output_audio, inference_transcript, sentence_audio, previous_audio_tensors
def update_input_audio(prompt_end_time, edit_start_time, edit_end_time):
    """Build the handler that fits the time sliders to a newly selected input audio.

    The handler sets each slider's maximum to the clip duration in seconds, rounded to two
    decimals. It then resets the prompt end and edit end to that duration and the edit start
    to 0. When the audio input is cleared (``audio_path`` is None), the sliders are left
    unchanged instead of failing in ``torchaudio.info``.
    """
    def impl(audio_path):
        if audio_path is None:
            return [prompt_end_time.update(), edit_start_time.update(), edit_end_time.update()]
        info = torchaudio.info(audio_path)
        max_time = round(info.num_frames / info.sample_rate, 2)
        return [
            prompt_end_time.update(maximum=max_time, value=max_time),
            edit_start_time.update(maximum=max_time, value=0),
            edit_end_time.update(maximum=max_time, value=max_time),
        ]
    return impl
def change_mode(prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls):
    """Build the handler that shows only the controls relevant to the selected mode.

    - The prompt end slider is shown for every mode except "Edit".
    - The split option and the Long TTS controls are shown only for "Long TTS".
    - The edit word mode, word selectors and precise segment sliders are shown only for "Edit".
    """
    def impl(mode):
        editing = mode == "Edit"
        long_form = mode == "Long TTS"
        edit_only = (edit_word_mode, segment_control, precise_segment_control)
        return [
            prompt_end_time.update(visible=not editing),
            split_text.update(visible=long_form),
            *(control.update(visible=editing) for control in edit_only),
            long_tts_controls.update(visible=long_form),
        ]
    return impl
def load_sentence(selected_sentence, codec_audio_sr, audio_tensors):
    """Return WAV bytes for the sentence picked in the "Sentence" dropdown.

    Entries look like "<idx>: <text>"; the index before the first colon selects the chunk
    from ``audio_tensors``. Returns None when nothing is selected.
    """
    if selected_sentence is not None:
        index_text = selected_sentence[:selected_sentence.find(':')]
        return get_output_audio([audio_tensors[int(index_text)]], codec_audio_sr)
    return None
def update_bound_word(is_first_word, edit_time):
    """Build the handler that moves an edit-span slider to match a selected word.

    Args:
        is_first_word: True for the "First word to edit" dropdown, where the bound is the word
            start. False for "Last word to edit", where the bound is the word end.
        edit_time: the slider component to update.

    The handler parses a "<start> <word> <end>" choice. In "Replace half" mode it uses the
    word's midpoint for either bound. With no word selected (None, e.g. after an empty
    transcription), it leaves the slider unchanged instead of failing.
    """
    def impl(selected_word, edit_word_mode):
        if selected_word is None:
            return edit_time.update()
        parts = selected_word.split(' ')
        word_start_time = float(parts[0])
        word_end_time = float(parts[-1])
        if edit_word_mode == "Replace half":
            bound_time = (word_start_time + word_end_time) / 2
        elif is_first_word:
            bound_time = word_start_time
        else:
            bound_time = word_end_time
        return edit_time.update(value=bound_time)
    return impl
def update_bound_words(edit_start_time, edit_end_time):
    """Build the handler that recomputes both edit-span sliders when the edit word mode changes.

    The first selected word bounds the span with its start time and the last word with its end
    time. In "Replace half" mode both use the word midpoints instead. This mirrors the
    per-dropdown handlers wired to the word selectors.
    """
    def impl(from_selected_word, to_selected_word, edit_word_mode):
        return [
            update_bound_word(True, edit_start_time)(from_selected_word, edit_word_mode),
            # The last word is not the first word: its end time (not start) closes the span.
            update_bound_word(False, edit_end_time)(to_selected_word, edit_word_mode),
        ]
    return impl
# HTML help text rendered in the "?" accordion next to the "Smart transcript" checkbox.
# It explains what to type in the Text box for each mode, with the option enabled or disabled.
smart_transcript_info = """
If enabled, the target transcript will be constructed for you:</br>
- In TTS and Long TTS mode just write the text you want to synthesize.</br>
- In Edit mode just write the text to replace selected editing segment.</br>
If disabled, you should write the target transcript yourself:</br>
- In TTS mode write prompt transcript followed by generation transcript.</br>
- In Long TTS select split by newline (<b>SENTENCE SPLIT WON'T WORK</b>) and start each line with a prompt transcript.</br>
- In Edit mode write full prompt</br>
"""
# Demo texts for the bundled demo clip, keyed by mode, then by Smart transcript setting.
# "smart" is what the user types with Smart transcript on. "regular" is the full target
# transcript, with the prompt words included, for when it is off.
demo_text = {
    "TTS": {
        "smart": "I cannot believe that the same model can also do text to speech synthesis as well!",
        "regular": "But when I had approached so near to them, the common I cannot believe that the same model can also do text to speech synthesis as well!",
    },
    "Edit": {
        "smart": "saw the mirage of the lake in the distance,",
        "regular": "But when I saw the mirage of the lake in the distance, which the sense deceives, Lost not by distance any of its marks,",
    },
    "Long TTS": {
        "smart": (
            "You can run TTS on a big text!\n"
            "Just write it line-by-line. Or sentence-by-sentence.\n"
            "If some sentences sound odd, just rerun TTS on them, no need to generate the whole text again!"
        ),
        "regular": (
            "But when I had approached so near to them, the common You can run TTS on a big text!\n"
            "But when I had approached so near to them, the common Just write it line-by-line. Or sentence-by-sentence.\n"
            "But when I had approached so near to them, the common If some sentences sound odd, just rerun TTS on them, no need to generate the whole text again!"
        ),
    },
}

# Every demo string. Used to tell whether the Text box still holds demo content.
all_demo_texts = {text for variants in demo_text.values() for text in variants.values()}

# Word choices for the demo clip, in the "<start> <word> <end>" format of the word dropdowns.
demo_words = [
    '0.0 But 0.12', '0.12 when 0.26', '0.26 I 0.44', '0.44 had 0.6',
    '0.6 approached 0.94', '0.94 so 1.42', '1.42 near 1.78', '1.78 to 2.02',
    '2.02 them, 2.24', '2.52 the 2.58', '2.58 common 2.9', '2.9 object, 3.3',
    '3.72 which 3.78', '3.78 the 3.98', '3.98 sense 4.18', '4.18 deceives, 4.88',
    '5.06 lost 5.26', '5.26 not 5.74', '5.74 by 6.08', '6.08 distance 6.36',
    '6.36 any 6.92', '6.92 of 7.12', '7.12 its 7.26', '7.26 marks. 7.54',
]
def update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time):
    """Build the handler that keeps the demo inputs in sync with the selected options.

    The handler is wired to the Mode, Smart transcript and Edit word mode controls.
    - It swaps the Text box to the demo text for the new settings, and resets the prompt end
      time to 3.01 s, only while the box still holds one of the demo texts.
    - It moves each edit-word dropdown to its demo default only while it is empty or still
      on a demo default.

    It needs the *current* Text box and dropdown values as extra inputs. A component's
    ``.value`` is only its initial value, so comparing ``transcript.value`` always saw the
    demo text and overwrote whatever the user had typed. When called without those inputs,
    it falls back to ``transcript.value`` and treats the dropdowns as empty.
    """
    demo_from_words = {"Replace half": "0.26 I 0.44", "Replace all": "0.44 had 0.6"}
    demo_to_words = {"Replace half": "3.72 which 3.78", "Replace all": "2.9 object, 3.3"}

    def word_update(component, current, defaults, key):
        # Only steer dropdowns that the user has not moved off the demo defaults.
        if current is None or current in defaults.values():
            return component.update(value=defaults[key])
        return component.update()

    def impl(mode, smart_transcript, edit_word_mode,
             current_transcript=None, current_from_word=None, current_to_word=None):
        if current_transcript is None:
            current_transcript = transcript.value
        key = "Replace half" if edit_word_mode == "Replace half" else "Replace all"
        from_update = word_update(edit_from_word, current_from_word, demo_from_words, key)
        to_update = word_update(edit_to_word, current_to_word, demo_to_words, key)
        if current_transcript not in all_demo_texts:
            # User-written text: leave the Text box and prompt end time untouched.
            return [transcript.update(), from_update, to_update, prompt_end_time.update()]
        return [
            transcript.update(value=demo_text[mode]["smart" if smart_transcript else "regular"]),
            from_update,
            to_update,
            prompt_end_time.update(value=3.01),
        ]
    return impl


with gr.Blocks() as app:
    # Model loading.
    with gr.Row():
        with gr.Column(scale=2):
            load_models_btn = gr.Button(value="Load models")
        with gr.Column(scale=5):
            with gr.Accordion("Select models", open=False):
                with gr.Row():
                    voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"])
                    whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
                                                    choices=["tiny.en", "base.en", "small.en", "medium.en", "large"])

    with gr.Row():
        # Input audio and its Whisper transcript.
        with gr.Column(scale=2):
            input_audio = gr.Audio(value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=False)
            with gr.Group():
                original_transcript = gr.Textbox(label="Original transcript", lines=5, interactive=False,
                                                 info="Use whisper model to get the transcript. Fix it if necessary.")
                with gr.Accordion("Word start time", open=False):
                    transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word")
                with gr.Accordion("Word end time", open=False):
                    transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word")
            transcribe_btn = gr.Button(value="Transcribe", interactive=False)

        # Target text and mode-specific controls.
        with gr.Column(scale=3):
            with gr.Group():
                transcript = gr.Textbox(label="Text", lines=7, value=demo_text["TTS"]["smart"])
                with gr.Row():
                    smart_transcript = gr.Checkbox(label="Smart transcript", value=True)
                    with gr.Accordion(label="?", open=False):
                        info = gr.HTML(value=smart_transcript_info)

            mode = gr.Radio(label="Mode", choices=["TTS", "Edit", "Long TTS"], value="TTS")
            prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.01, value=3.01)
            split_text = gr.Radio(label="Split text", choices=["Newline", "Sentence"], value="Newline", visible=False,
                                  info="Split text into parts and run TTS for each part.")
            edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace half", visible=False,
                                      info="What to do with first and last word")
            with gr.Row(visible=False) as segment_control:
                edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, interactive=True)
                edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, interactive=True)
            with gr.Accordion("Precise segment control", open=False, visible=False) as precise_segment_control:
                edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=60, step=0.01, value=0)
                edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=60, step=0.01, value=60)
            run_btn = gr.Button(value="Run", interactive=False)

        # Results and Long TTS sentence regeneration.
        with gr.Column(scale=2):
            output_audio = gr.Audio(label="Output Audio")
            with gr.Accordion("Inference transcript", open=False):
                inference_transcript = gr.Textbox(label="Inference transcript", lines=5, interactive=False,
                                                  info="Inference was performed on this transcript.")
            with gr.Group(visible=False) as long_tts_controls:
                sentence_selector = gr.Dropdown(label="Sentence", value=None,
                                                info="Select sentence you want to regenerate")
                sentence_audio = gr.Audio(label="Sentence Audio", scale=2)
                rerun_btn = gr.Button(value="Rerun", interactive=False)

    # Decoding parameters passed to VoiceCraft inference.
    with gr.Row():
        with gr.Accordion("VoiceCraft config", open=False):
            left_margin = gr.Number(label="left_margin", value=0.08)
            right_margin = gr.Number(label="right_margin", value=0.08)
            codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000)
            codec_sr = gr.Number(label="codec_sr", value=50)
            top_k = gr.Number(label="top_k", value=0)
            top_p = gr.Number(label="top_p", value=0.8)
            temperature = gr.Number(label="temperature", value=1)
            stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3], value=3,
                                       info="if there are long silence in the generated audio, reduce the stop_repetition to 3, 2 or even 1, -1 = disabled")
            sample_batch_size = gr.Number(label="sample_batch_size", value=4, precision=0,
                                          info="generate this many samples and choose the shortest one")
            kvcache = gr.Radio(label="kvcache", choices=[0, 1], value=1,
                               info="set to 0 to use less VRAM, but with slower inference")
            silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]")

    # Per-session state: generated chunks of the last run and Whisper word timings.
    audio_tensors = gr.State()
    word_info = gr.State()

    # The demo handler reads the live Text box and word dropdown values so user input is kept.
    demo_handler = update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time)
    demo_inputs = [mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word]
    demo_outputs = [transcript, edit_from_word, edit_to_word, prompt_end_time]
    for demo_trigger in (mode, edit_word_mode, smart_transcript):
        demo_trigger.change(fn=demo_handler, inputs=demo_inputs, outputs=demo_outputs)

    load_models_btn.click(fn=load_models(input_audio, transcribe_btn, run_btn, rerun_btn),
                          inputs=[whisper_model_choice, voicecraft_model_choice],
                          outputs=[input_audio, transcribe_btn, run_btn, rerun_btn])

    input_audio.change(fn=update_input_audio(prompt_end_time, edit_start_time, edit_end_time),
                       inputs=[input_audio], outputs=[prompt_end_time, edit_start_time, edit_end_time])
    transcribe_btn.click(fn=transcribe, inputs=[input_audio],
                         outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time, edit_from_word, edit_to_word, word_info])
    mode.change(fn=change_mode(prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls),
                inputs=[mode], outputs=[prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls])

    run_btn.click(fn=run,
                  inputs=[
                      left_margin, right_margin,
                      codec_audio_sr, codec_sr,
                      top_k, top_p, temperature,
                      stop_repetition, sample_batch_size,
                      kvcache, silence_tokens,
                      input_audio, word_info, transcript, smart_transcript,
                      mode, prompt_end_time, edit_start_time, edit_end_time,
                      split_text, sentence_selector, audio_tensors,
                  ],
                  outputs=[
                      output_audio, inference_transcript, sentence_selector, audio_tensors,
                  ])

    sentence_selector.change(fn=load_sentence, inputs=[sentence_selector, codec_audio_sr, audio_tensors], outputs=[sentence_audio])
    # Rerun reuses `run` with a fixed "Rerun" mode to regenerate only the selected sentence.
    rerun_btn.click(fn=run,
                    inputs=[
                        left_margin, right_margin,
                        codec_audio_sr, codec_sr,
                        top_k, top_p, temperature,
                        stop_repetition, sample_batch_size,
                        kvcache, silence_tokens,
                        input_audio, word_info, transcript, smart_transcript,
                        gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time,
                        split_text, sentence_selector, audio_tensors,
                    ],
                    outputs=[
                        output_audio, inference_transcript, sentence_audio, audio_tensors,
                    ])

    edit_word_mode.change(fn=update_bound_words(edit_start_time, edit_end_time),
                          inputs=[edit_from_word, edit_to_word, edit_word_mode], outputs=[edit_start_time, edit_end_time])
    edit_from_word.change(fn=update_bound_word(True, edit_start_time),
                          inputs=[edit_from_word, edit_word_mode], outputs=[edit_start_time])
    edit_to_word.change(fn=update_bound_word(False, edit_end_time),
                        inputs=[edit_to_word, edit_word_mode], outputs=[edit_end_time])
if __name__ == "__main__":
    # Start the web UI only when run as a script, not when this module is imported.
    app.launch()