deprecated .update not used anymore, better error handling, can use voicecraft without whisper

This commit is contained in:
Stepan Zuev 2024-04-03 05:01:55 +03:00
parent 5cef625c1b
commit 74fa65979d

View File

@ -6,62 +6,63 @@ from data.tokenizer import (
TextTokenizer,
)
from models import voicecraft
import whisper
from whisper.tokenizer import get_tokenizer
import os
import io
whisper_model = None
voicecraft_model = None
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_models(whisper_model_choice, voicecraft_model_choice):
whisper_model, voicecraft_model = None, None
if whisper_model_choice is not None:
import whisper
from whisper.tokenizer import get_tokenizer
whisper_model = {
"model": whisper.load_model(whisper_model_choice),
"tokenizer": get_tokenizer(multilingual=False)
}
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = "cuda" if torch.cuda.is_available() else "cpu"
voicecraft_name = f"{voicecraft_model_choice}.pth"
ckpt_fn = f"./pretrained_models/{voicecraft_name}"
encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
if not os.path.exists(ckpt_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}")
if not os.path.exists(encodec_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th")
ckpt = torch.load(ckpt_fn, map_location="cpu")
model = voicecraft.VoiceCraft(ckpt["config"])
model.load_state_dict(ckpt["model"])
model.to(device)
model.eval()
voicecraft_model = {
"ckpt": ckpt,
"model": model,
"text_tokenizer": TextTokenizer(backend="espeak"),
"audio_tokenizer": AudioTokenizer(signature=encodec_fn)
}
return [
whisper_model,
voicecraft_model,
gr.Audio(interactive=True),
]
def load_models(input_audio, transcribe_btn, run_btn, rerun_btn):
def impl(whisper_model_choice, voicecraft_model_choice):
global whisper_model, voicecraft_model
whisper_model = whisper.load_model(whisper_model_choice)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
voicecraft_name = f"{voicecraft_model_choice}.pth"
ckpt_fn = f"./pretrained_models/{voicecraft_name}"
encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
if not os.path.exists(ckpt_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}")
if not os.path.exists(encodec_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th")
voicecraft_model = {}
voicecraft_model["ckpt"] = torch.load(ckpt_fn, map_location="cpu")
voicecraft_model["model"] = voicecraft.VoiceCraft(voicecraft_model["ckpt"]["config"])
voicecraft_model["model"].load_state_dict(voicecraft_model["ckpt"]["model"])
voicecraft_model["model"].to(device)
voicecraft_model["model"].eval()
voicecraft_model["text_tokenizer"] = TextTokenizer(backend="espeak")
voicecraft_model["audio_tokenizer"] = AudioTokenizer(signature=encodec_fn)
return [
input_audio.update(interactive=True),
transcribe_btn.update(interactive=True),
run_btn.update(interactive=True),
rerun_btn.update(interactive=True)
]
return impl
def transcribe(audio_path):
tokenizer = get_tokenizer(multilingual=False)
def transcribe(whisper_model, audio_path):
if whisper_model is None:
raise gr.Error("Whisper model not loaded")
number_tokens = [
i
for i in range(tokenizer.eot)
if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
for i in range(whisper_model["tokenizer"].eot)
if all(c in "0123456789" for c in whisper_model["tokenizer"].decode([i]).removeprefix(" "))
]
result = whisper_model.transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True)
result = whisper_model["model"].transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True)
words = [word_info for segment in result["segments"] for word_info in segment["words"]]
transcript = result["text"]
@ -69,12 +70,12 @@ def transcribe(audio_path):
transcript_with_end_time = " ".join([f"{word['word']} {word['end']}" for word in words])
choices = [f"{word['start']} {word['word']} {word['end']}" for word in words]
edit_from_word = gr.Dropdown(label="First word to edit", value=choices[0], choices=choices, interactive=True)
edit_to_word = gr.Dropdown(label="Last word to edit", value=choices[-1], choices=choices, interactive=True)
return [
transcript, transcript_with_start_time, transcript_with_end_time,
edit_from_word, edit_to_word, words
gr.Dropdown(value=choices[0], choices=choices, interactive=True), # edit_from_word
gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # edit_to_word
words
]
@ -86,11 +87,16 @@ def get_output_audio(audio_tensors, codec_audio_sr):
return buffer.read()
def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
def run(voicecraft_model, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
stop_repetition, sample_batch_size, kvcache, silence_tokens,
audio_path, word_info, transcript, smart_transcript,
mode, prompt_end_time, edit_start_time, edit_end_time,
split_text, selected_sentence, previous_audio_tensors):
if voicecraft_model is None:
raise gr.Error("VoiceCraft model not loaded")
if smart_transcript and (word_info is None):
raise gr.Error("Can't use smart transcript: whisper transcript not found")
if mode == "Long TTS":
if split_text == "Newline":
sentences = transcript.split('\n')
@ -104,6 +110,7 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe
else:
sentences = [transcript.replace("\n", " ")]
device = "cuda" if torch.cuda.is_available() else "cpu"
info = torchaudio.info(audio_path)
audio_dur = info.num_frames / info.sample_rate
@ -116,7 +123,7 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe
if mode != "Edit":
from inference_tts_scale import inference_one_sample
if smart_transcript:
if smart_transcript:
target_transcript = ""
for word in word_info:
if word["end"] < prompt_end_time:
@ -175,8 +182,7 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe
if mode != "Rerun":
output_audio = get_output_audio(audio_tensors, codec_audio_sr)
sentences = [f"{idx}: {text}" for idx, text in enumerate(sentences)]
component = gr.Dropdown(label="Sentence", choices=sentences, value=sentences[0],
info="Select sentence you want to regenerate")
component = gr.Dropdown(choices=sentences, value=sentences[0])
return output_audio, inference_transcript, component, audio_tensors
else:
previous_audio_tensors[selected_sentence_idx] = audio_tensors[0]
@ -185,29 +191,25 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe
return output_audio, inference_transcript, sentence_audio, previous_audio_tensors
def update_input_audio(prompt_end_time, edit_start_time, edit_end_time):
def impl(audio_path):
info = torchaudio.info(audio_path)
max_time = round(info.num_frames / info.sample_rate, 2)
return [
prompt_end_time.update(maximum=max_time, value=max_time),
edit_start_time.update(maximum=max_time, value=0),
edit_end_time.update(maximum=max_time, value=max_time),
]
return impl
def update_input_audio(audio_path):
info = torchaudio.info(audio_path)
max_time = round(info.num_frames / info.sample_rate, 2)
return [
gr.Slider(maximum=max_time, value=max_time),
gr.Slider(maximum=max_time, value=0),
gr.Slider(maximum=max_time, value=max_time),
]
def change_mode(prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls):
def impl(mode):
return [
prompt_end_time.update(visible=mode != "Edit"),
split_text.update(visible=mode == "Long TTS"),
edit_word_mode.update(visible=mode == "Edit"),
segment_control.update(visible=mode == "Edit"),
precise_segment_control.update(visible=mode == "Edit"),
long_tts_controls.update(visible=mode == "Long TTS"),
]
return impl
def change_mode(mode):
return [
gr.Slider(visible=mode != "Edit"),
gr.Radio(visible=mode == "Long TTS"),
gr.Radio(visible=mode == "Edit"),
gr.Row(visible=mode == "Edit"),
gr.Accordion(visible=mode == "Edit"),
gr.Group(visible=mode == "Long TTS"),
]
def load_sentence(selected_sentence, codec_audio_sr, audio_tensors):
@ -218,28 +220,27 @@ def load_sentence(selected_sentence, codec_audio_sr, audio_tensors):
return get_output_audio([audio_tensors[selected_sentence_idx]], codec_audio_sr)
def update_bound_word(is_first_word, edit_time):
def impl(selected_word, edit_word_mode):
word_start_time = float(selected_word.split(' ')[0])
word_end_time = float(selected_word.split(' ')[-1])
if edit_word_mode == "Replace half":
bound_time = (word_start_time + word_end_time) / 2
elif is_first_word:
bound_time = word_start_time
else:
bound_time = word_end_time
def update_bound_word(is_first_word, selected_word, edit_word_mode):
if selected_word is None:
return None
return edit_time.update(value=bound_time)
return impl
word_start_time = float(selected_word.split(' ')[0])
word_end_time = float(selected_word.split(' ')[-1])
if edit_word_mode == "Replace half":
bound_time = (word_start_time + word_end_time) / 2
elif is_first_word:
bound_time = word_start_time
else:
bound_time = word_end_time
return bound_time
def update_bound_words(edit_start_time, edit_end_time):
def impl(from_selected_word, to_selected_word, edit_word_mode):
return [
update_bound_word(True, edit_start_time)(from_selected_word, edit_word_mode),
update_bound_word(True, edit_end_time)(to_selected_word, edit_word_mode),
]
return impl
def update_bound_words(from_selected_word, to_selected_word, edit_word_mode):
return [
update_bound_word(True, from_selected_word, edit_word_mode),
update_bound_word(False, to_selected_word, edit_word_mode),
]
smart_transcript_info = """
@ -251,6 +252,7 @@ If disabled, you should write the target transcript yourself:</br>
- In Long TTS select split by newline (<b>SENTENCE SPLIT WON'T WORK</b>) and start each line with a prompt transcript.</br>
- In Edit mode write full prompt</br>
"""
demo_text = {
"TTS": {
"smart": "I cannot believe that the same model can also do text to speech synthesis as well!",
@ -269,7 +271,9 @@ demo_text = {
"But when I had approached so near to them, the common If some sentences sound odd, just rerun TTS on them, no need to generate the whole text again!"
}
}
all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()}
demo_words = [
'0.0 But 0.12', '0.12 when 0.26', '0.26 I 0.44', '0.44 had 0.6', '0.6 approached 0.94', '0.94 so 1.42',
'1.42 near 1.78', '1.78 to 2.02', '2.02 them, 2.24', '2.52 the 2.58', '2.58 common 2.9', '2.9 object, 3.3',
@ -278,19 +282,17 @@ demo_words = [
]
def update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time):
def impl(mode, smart_transcript, edit_word_mode):
if transcript.value not in all_demo_texts:
return [transcript, edit_from_word, edit_to_word, prompt_end_time]
replace_half = edit_word_mode == "Replace half"
return [
transcript.update(value=demo_text[mode]["smart" if smart_transcript else "regular"]),
edit_from_word.update(value="0.26 I 0.44" if replace_half else "0.44 had 0.6"),
edit_to_word.update(value="3.72 which 3.78" if replace_half else "2.9 object, 3.3"),
prompt_end_time.update(value=3.01),
]
return impl
def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time):
if transcript not in all_demo_texts:
return transcript, edit_from_word, edit_to_word, prompt_end_time
replace_half = edit_word_mode == "Replace half"
return [
demo_text[mode]["smart" if smart_transcript else "regular"],
"0.26 I 0.44" if replace_half else "0.44 had 0.6",
"3.72 which 3.78" if replace_half else "2.9 object, 3.3",
3.01,
]
with gr.Blocks() as app:
@ -302,7 +304,7 @@ with gr.Blocks() as app:
with gr.Row():
voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"])
whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
choices=["tiny.en", "base.en", "small.en", "medium.en", "large"])
choices=[None, "tiny.en", "base.en", "small.en", "medium.en", "large"])
with gr.Row():
with gr.Column(scale=2):
@ -315,7 +317,7 @@ with gr.Blocks() as app:
with gr.Accordion("Word end time", open=False):
transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word")
transcribe_btn = gr.Button(value="Transcribe", interactive=False)
transcribe_btn = gr.Button(value="Transcribe")
with gr.Column(scale=3):
with gr.Group():
@ -338,7 +340,7 @@ with gr.Blocks() as app:
edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=60, step=0.01, value=0)
edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=60, step=0.01, value=60)
run_btn = gr.Button(value="Run", interactive=False)
run_btn = gr.Button(value="Run")
with gr.Column(scale=2):
output_audio = gr.Audio(label="Output Audio")
@ -349,7 +351,7 @@ with gr.Blocks() as app:
sentence_selector = gr.Dropdown(label="Sentence", value=None,
info="Select sentence you want to regenerate")
sentence_audio = gr.Audio(label="Sentence Audio", scale=2)
rerun_btn = gr.Button(value="Rerun", interactive=False)
rerun_btn = gr.Button(value="Rerun")
with gr.Row():
with gr.Accordion("VoiceCraft config", open=False):
@ -369,34 +371,40 @@ with gr.Blocks() as app:
silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]")
whisper_model = gr.State()
voicecraft_model = gr.State()
audio_tensors = gr.State()
word_info = gr.State()
mode.change(fn=update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time),
inputs=[mode, smart_transcript, edit_word_mode],
mode.change(fn=update_demo,
inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time],
outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time])
edit_word_mode.change(fn=update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time),
inputs=[mode, smart_transcript, edit_word_mode],
edit_word_mode.change(fn=update_demo,
inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time],
outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time])
smart_transcript.change(fn=update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time),
inputs=[mode, smart_transcript, edit_word_mode],
smart_transcript.change(fn=update_demo,
inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time],
outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time])
load_models_btn.click(fn=load_models(input_audio, transcribe_btn, run_btn, rerun_btn),
load_models_btn.click(fn=load_models,
inputs=[whisper_model_choice, voicecraft_model_choice],
outputs=[input_audio, transcribe_btn, run_btn, rerun_btn])
input_audio.change(fn=update_input_audio(prompt_end_time, edit_start_time, edit_end_time),
inputs=[input_audio], outputs=[prompt_end_time, edit_start_time, edit_end_time])
transcribe_btn.click(fn=transcribe, inputs=[input_audio],
outputs=[whisper_model, voicecraft_model, input_audio])
input_audio.change(fn=update_input_audio,
inputs=[input_audio],
outputs=[prompt_end_time, edit_start_time, edit_end_time])
transcribe_btn.click(fn=transcribe,
inputs=[whisper_model, input_audio],
outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time, edit_from_word, edit_to_word, word_info])
mode.change(fn=change_mode(prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls),
inputs=[mode], outputs=[prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls])
mode.change(fn=change_mode,
inputs=[mode],
outputs=[prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls])
run_btn.click(fn=run,
inputs=[
left_margin, right_margin,
voicecraft_model, left_margin, right_margin,
codec_audio_sr, codec_sr,
top_k, top_p, temperature,
stop_repetition, sample_batch_size,
@ -405,14 +413,14 @@ with gr.Blocks() as app:
mode, prompt_end_time, edit_start_time, edit_end_time,
split_text, sentence_selector, audio_tensors
],
outputs=[
output_audio, inference_transcript, sentence_selector, audio_tensors
])
outputs=[output_audio, inference_transcript, sentence_selector, audio_tensors])
sentence_selector.change(fn=load_sentence, inputs=[sentence_selector, codec_audio_sr, audio_tensors], outputs=[sentence_audio])
sentence_selector.change(fn=load_sentence,
inputs=[sentence_selector, codec_audio_sr, audio_tensors],
outputs=[sentence_audio])
rerun_btn.click(fn=run,
inputs=[
left_margin, right_margin,
voicecraft_model, left_margin, right_margin,
codec_audio_sr, codec_sr,
top_k, top_p, temperature,
stop_repetition, sample_batch_size,
@ -421,16 +429,17 @@ with gr.Blocks() as app:
gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time,
split_text, sentence_selector, audio_tensors
],
outputs=[
output_audio, inference_transcript, sentence_audio, audio_tensors
])
outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors])
edit_word_mode.change(fn=update_bound_words(edit_start_time, edit_end_time),
inputs=[edit_from_word, edit_to_word, edit_word_mode], outputs=[edit_start_time, edit_end_time])
edit_from_word.change(fn=update_bound_word(True, edit_start_time),
inputs=[edit_from_word, edit_word_mode], outputs=[edit_start_time])
edit_to_word.change(fn=update_bound_word(False, edit_end_time),
inputs=[edit_to_word, edit_word_mode], outputs=[edit_end_time])
edit_from_word.change(fn=update_bound_word,
inputs=[gr.State(True), edit_from_word, edit_word_mode],
outputs=[edit_start_time])
edit_to_word.change(fn=update_bound_word,
inputs=[gr.State(False), edit_to_word, edit_word_mode],
outputs=[edit_end_time])
edit_word_mode.change(fn=update_bound_words,
inputs=[edit_from_word, edit_to_word, edit_word_mode],
outputs=[edit_start_time, edit_end_time])
if __name__ == "__main__":