whisperx added

Stepan Zuev 2024-04-04 22:22:27 +03:00
parent 1a219cf6da
commit 3d3f32ba7e
3 changed files with 170 additions and 67 deletions

View File

@@ -111,7 +111,7 @@ It is ready to use on [default url](http://127.0.0.1:7860).
 6. (optionally) Rerun part-by-part in Long TTS mode
 
 ### Some features
-Smart transcript: write only what you want to generate, but don't work if you edit original transcript
+Smart transcript: write only what you want to generate
 TTS mode: Zero-shot TTS

View File

@@ -10,9 +10,16 @@ import os
 import io
 import numpy as np
 import random
+import uuid
 
-whisper_model, voicecraft_model = None, None
+TMP_PATH = "./demo/temp"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+transcribe_model, align_model, voicecraft_model = None, None, None
+
+def get_random_string():
+    return "".join(str(uuid.uuid4()).split("-"))
 
 def seed_everything(seed):
@@ -26,22 +33,63 @@ def seed_everything(seed):
     torch.backends.cudnn.deterministic = True
 
-def load_models(whisper_model_choice, voicecraft_model_choice):
-    global whisper_model, voicecraft_model
+class WhisperxAlignModel:
+    def __init__(self):
+        from whisperx import load_align_model
+        self.model, self.metadata = load_align_model(language_code="en", device=device)
+
+    def align(self, segments, audio_path):
+        from whisperx import align, load_audio
+        audio = load_audio(audio_path)
+        return align(segments, self.model, self.metadata, audio, device, return_char_alignments=False)["segments"]
+
+class WhisperModel:
+    def __init__(self, model_name):
+        from whisper import load_model
+        self.model = load_model(model_name, device)
 
-    if whisper_model_choice is not None:
-        import whisper
-        from whisper.tokenizer import get_tokenizer
-        whisper_model = {
-            "model": whisper.load_model(whisper_model_choice),
-            "tokenizer": get_tokenizer(multilingual=False)
-        }
+        from whisper.tokenizer import get_tokenizer
+        tokenizer = get_tokenizer(multilingual=False)
+        # Token ids that decode to bare digits: suppressing them makes whisper
+        # spell numbers out as words, which keeps the transcript usable for TTS.
+        self.suppress_tokens = [-1] + [
+            i
+            for i in range(tokenizer.eot)
+            if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
+        ]
+
+    def transcribe(self, audio_path):
+        return self.model.transcribe(audio_path, suppress_tokens=self.suppress_tokens, word_timestamps=True)["segments"]
+
+class WhisperxModel:
+    def __init__(self, model_name, align_model: WhisperxAlignModel):
+        from whisperx import load_model
+        self.model = load_model(model_name, device, asr_options={"suppress_numerals": True})
+        self.align_model = align_model
+
+    def transcribe(self, audio_path):
+        segments = self.model.transcribe(audio_path, batch_size=8)["segments"]
+        return self.align_model.align(segments, audio_path)
+
+def load_models(whisper_backend_name, whisper_model_name, alignment_model_name, voicecraft_model_name):
+    global transcribe_model, align_model, voicecraft_model
 
     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-    device = "cuda" if torch.cuda.is_available() else "cpu"
 
+    if alignment_model_name is not None:
+        align_model = WhisperxAlignModel()
+
+    if whisper_model_name is not None:
+        if whisper_backend_name == "whisper":
+            transcribe_model = WhisperModel(whisper_model_name)
+        else:
+            if align_model is None:
+                raise gr.Error("Align model required for whisperx backend")
+            transcribe_model = WhisperxModel(whisper_model_name, align_model)
 
-    voicecraft_name = f"{voicecraft_model_choice}.pth"
+    voicecraft_name = f"{voicecraft_model_name}.pth"
     ckpt_fn = f"./pretrained_models/{voicecraft_name}"
     encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
     if not os.path.exists(ckpt_fn):
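
Stepping back from the hunk: the three new classes give load_models two interchangeable transcription backends behind a single transcribe(audio_path) interface. A minimal sketch of the wiring, assuming whisperx is installed and using the demo audio already shipped in the repo (not part of the diff):

# Sketch only: how the backends compose once load_models has run.
align_model = WhisperxAlignModel()                # wav2vec2-based aligner via whisperx
asr = WhisperxModel("base.en", align_model)       # whisperX backend; the aligner is mandatory here
# asr = WhisperModel("base.en")                   # or the plain openai-whisper backend, no aligner needed
segments = asr.transcribe("./demo/84_121550_000074_000000.wav")
# both backends return segments with word-level entries; the whisperX path also adds a "score"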
@@ -66,31 +114,78 @@ def load_models(whisper_model_choice, voicecraft_model_choice):
 
     return gr.Accordion()
 
+def get_transcribe_state(segments):
+    words_info = [word_info for segment in segments for word_info in segment["words"]]
+    return {
+        "segments": segments,
+        "transcript": " ".join([segment["text"] for segment in segments]),
+        "words_info": words_info,
+        "transcript_with_start_time": " ".join([f"{word['start']} {word['word']}" for word in words_info]),
+        "transcript_with_end_time": " ".join([f"{word['word']} {word['end']}" for word in words_info]),
+        "word_bounds": [f"{word['start']} {word['word']} {word['end']}" for word in words_info]
+    }
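
To make the state shape concrete, here is roughly what the new get_transcribe_state helper returns for a minimal one-segment input; the timings and scores are made up for illustration:

# Illustrative only, not part of the diff. Mirrors get_transcribe_state above.
segments = [{"text": "But when", "words": [
    {"word": "But", "start": 0.03, "end": 0.15, "score": 0.83},
    {"word": "when", "start": 0.19, "end": 0.33, "score": 0.88},
]}]
state = get_transcribe_state(segments)
# state["transcript"]                 == "But when"
# state["transcript_with_start_time"] == "0.03 But 0.19 when"
# state["transcript_with_end_time"]   == "But 0.15 when 0.33"
# state["word_bounds"]                == ["0.03 But 0.15", "0.19 when 0.33"]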
 def transcribe(seed, audio_path):
-    if whisper_model is None:
-        raise gr.Error("Whisper model not loaded")
+    if transcribe_model is None:
+        raise gr.Error("Transcription model not loaded")
     seed_everything(seed)
 
-    number_tokens = [
-        i
-        for i in range(whisper_model["tokenizer"].eot)
-        if all(c in "0123456789" for c in whisper_model["tokenizer"].decode([i]).removeprefix(" "))
-    ]
-    result = whisper_model["model"].transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True)
-    words = [word_info for segment in result["segments"] for word_info in segment["words"]]
-
-    transcript = result["text"]
-    transcript_with_start_time = " ".join([f"{word['start']} {word['word']}" for word in words])
-    transcript_with_end_time = " ".join([f"{word['word']} {word['end']}" for word in words])
-    choices = [f"{word['start']} {word['word']} {word['end']}" for word in words]
+    segments = transcribe_model.transcribe(audio_path)
+    state = get_transcribe_state(segments)
 
     return [
-        transcript, transcript_with_start_time, transcript_with_end_time,
-        gr.Dropdown(value=choices[-1], choices=choices, interactive=True),  # prompt_to_word
-        gr.Dropdown(value=choices[0], choices=choices, interactive=True),  # edit_from_word
-        gr.Dropdown(value=choices[-1], choices=choices, interactive=True),  # edit_to_word
-        words
+        state["transcript"], state["transcript_with_start_time"], state["transcript_with_end_time"],
+        gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True),  # prompt_to_word
+        gr.Dropdown(value=state["word_bounds"][0], choices=state["word_bounds"], interactive=True),  # edit_from_word
+        gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True),  # edit_to_word
+        state
     ]
+def align_segments(transcript, audio_path):
+    from aeneas.executetask import ExecuteTask
+    from aeneas.task import Task
+    import json
+    config_string = 'task_language=eng|os_task_file_format=json|is_text_type=plain'
+
+    tmp_transcript_path = os.path.join(TMP_PATH, f"{get_random_string()}.txt")
+    tmp_sync_map_path = os.path.join(TMP_PATH, f"{get_random_string()}.json")
+    with open(tmp_transcript_path, "w") as f:
+        f.write(transcript)
+
+    task = Task(config_string=config_string)
+    task.audio_file_path_absolute = os.path.abspath(audio_path)
+    task.text_file_path_absolute = os.path.abspath(tmp_transcript_path)
+    task.sync_map_file_path_absolute = os.path.abspath(tmp_sync_map_path)
+    ExecuteTask(task).execute()
+    task.output_sync_map_file()
+
+    with open(tmp_sync_map_path, "r") as f:
+        return json.load(f)
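
For readers unfamiliar with aeneas: the sync map loaded at the end of align_segments is a JSON object whose fragments carry string timestamps and the transcript lines. Trimmed to the fields the align function below actually reads (illustrative, extra fields such as "id" omitted):

# Illustrative aeneas sync map, as loaded by align_segments above.
sync_map = {
    "fragments": [
        {"begin": "0.000", "end": "2.360", "lines": ["But when I had approached so near to them,"]},
        {"begin": "2.360", "end": "4.920", "lines": ["the common object, which the sense deceives,"]},
    ]
}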
+def align(seed, transcript, audio_path):
+    if align_model is None:
+        raise gr.Error("Align model not loaded")
+    seed_everything(seed)
+
+    fragments = align_segments(transcript, audio_path)
+    segments = [{
+        "start": float(fragment["begin"]),
+        "end": float(fragment["end"]),
+        "text": " ".join(fragment["lines"])
+    } for fragment in fragments["fragments"]]
+    segments = align_model.align(segments, audio_path)
+    state = get_transcribe_state(segments)
+    print(state)
+
+    return [
+        state["transcript_with_start_time"], state["transcript_with_end_time"],
+        gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True),  # prompt_to_word
+        gr.Dropdown(value=state["word_bounds"][0], choices=state["word_bounds"], interactive=True),  # edit_from_word
+        gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True),  # edit_to_word
+        state
+    ]
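
Worth noting: alignment here is two-stage. aeneas only yields fragment-level begin/end times for the user-corrected transcript, so align converts the fragments into whisperx-style segments and runs them through the same WhisperxAlignModel to recover per-word timestamps before get_transcribe_state rebuilds the dropdown choices.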
@@ -104,12 +199,12 @@ def get_output_audio(audio_tensors, codec_audio_sr):
 
 def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
         stop_repetition, sample_batch_size, kvcache, silence_tokens,
-        audio_path, word_info, transcript, smart_transcript,
+        audio_path, transcribe_state, transcript, smart_transcript,
         mode, prompt_end_time, edit_start_time, edit_end_time,
         split_text, selected_sentence, previous_audio_tensors):
     if voicecraft_model is None:
         raise gr.Error("VoiceCraft model not loaded")
-    if smart_transcript and (word_info is None):
+    if smart_transcript and (transcribe_state is None):
         raise gr.Error("Can't use smart transcript: whisper transcript not found")
 
     seed_everything(seed)
@@ -126,7 +221,6 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
     else:
         sentences = [transcript.replace("\n", " ")]
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     info = torchaudio.info(audio_path)
     audio_dur = info.num_frames / info.sample_rate
@@ -141,7 +235,7 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
 
     if smart_transcript:
         target_transcript = ""
-        for word in word_info:
+        for word in transcribe_state["words_info"]:
             if word["end"] < prompt_end_time:
                 target_transcript += word["word"]
             elif (word["start"] + word["end"]) / 2 < prompt_end_time:
@@ -169,13 +263,13 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
 
         if smart_transcript:
             target_transcript = ""
-            for word in word_info:
+            for word in transcribe_state["words_info"]:
                 if word["start"] < edit_start_time:
                     target_transcript += word["word"]
                 else:
                     break
             target_transcript += f" {sentence}"
-            for word in word_info:
+            for word in transcribe_state["words_info"]:
                 if word["end"] > edit_end_time:
                     target_transcript += word["word"]
                 else:
@@ -296,25 +390,25 @@ demo_text = {
 
 all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()}
 
 demo_words = [
-    '0.0 But 0.12', '0.12 when 0.26', '0.26 I 0.44', '0.44 had 0.6', '0.6 approached 0.94', '0.94 so 1.42',
-    '1.42 near 1.78', '1.78 to 2.02', '2.02 them, 2.24', '2.52 the 2.58', '2.58 common 2.9', '2.9 object, 3.3',
-    '3.72 which 3.78', '3.78 the 3.98', '3.98 sense 4.18', '4.18 deceives, 4.88', '5.06 lost 5.26', '5.26 not 5.74',
-    '5.74 by 6.08', '6.08 distance 6.36', '6.36 any 6.92', '6.92 of 7.12', '7.12 its 7.26', '7.26 marks. 7.54'
+    '0.029 But 0.149', '0.189 when 0.33', '0.43 I 0.49', '0.53 had 0.65', '0.711 approached 1.152', '1.352 so 1.593',
+    '1.693 near 1.933', '1.994 to 2.074', '2.134 them, 2.354', '2.535 the 2.655', '2.695 common 3.016', '3.196 object, 3.577',
+    '3.717 which 3.898', '3.958 the 4.058', '4.098 sense 4.359', '4.419 deceives, 4.92', '5.101 lost 5.481', '5.682 not 5.963',
+    '6.043 by 6.183', '6.223 distance 6.644', '6.905 any 7.065', '7.125 of 7.185', '7.245 its 7.346', '7.406 marks. 7.727'
 ]
 
-demo_word_info = [
-    {'word': ' But', 'start': 0.0, 'end': 0.12}, {'word': ' when', 'start': 0.12, 'end': 0.26},
-    {'word': ' I', 'start': 0.26, 'end': 0.44}, {'word': ' had', 'start': 0.44, 'end': 0.6},
-    {'word': ' approached', 'start': 0.6, 'end': 0.94}, {'word': ' so', 'start': 0.94, 'end': 1.42},
-    {'word': ' near', 'start': 1.42, 'end': 1.78}, {'word': ' to', 'start': 1.78, 'end': 2.02},
-    {'word': ' them,', 'start': 2.02, 'end': 2.24}, {'word': ' the', 'start': 2.52, 'end': 2.58},
-    {'word': ' common', 'start': 2.58, 'end': 2.9}, {'word': ' object,', 'start': 2.9, 'end': 3.3},
-    {'word': ' which', 'start': 3.72, 'end': 3.78}, {'word': ' the', 'start': 3.78, 'end': 3.98},
-    {'word': ' sense', 'start': 3.98, 'end': 4.18}, {'word': ' deceives,', 'start': 4.18, 'end': 4.88},
-    {'word': ' lost', 'start': 5.06, 'end': 5.26}, {'word': ' not', 'start': 5.26, 'end': 5.74},
-    {'word': ' by', 'start': 5.74, 'end': 6.08}, {'word': ' distance', 'start': 6.08, 'end': 6.36},
-    {'word': ' any', 'start': 6.36, 'end': 6.92}, {'word': ' of', 'start': 6.92, 'end': 7.12},
-    {'word': ' its', 'start': 7.12, 'end': 7.26}, {'word': ' marks.', 'start': 7.26, 'end': 7.54}
+demo_words_info = [
+    {'word': 'But', 'start': 0.029, 'end': 0.149, 'score': 0.834}, {'word': 'when', 'start': 0.189, 'end': 0.33, 'score': 0.879},
+    {'word': 'I', 'start': 0.43, 'end': 0.49, 'score': 0.984}, {'word': 'had', 'start': 0.53, 'end': 0.65, 'score': 0.998},
+    {'word': 'approached', 'start': 0.711, 'end': 1.152, 'score': 0.822}, {'word': 'so', 'start': 1.352, 'end': 1.593, 'score': 0.822},
+    {'word': 'near', 'start': 1.693, 'end': 1.933, 'score': 0.752}, {'word': 'to', 'start': 1.994, 'end': 2.074, 'score': 0.924},
+    {'word': 'them,', 'start': 2.134, 'end': 2.354, 'score': 0.914}, {'word': 'the', 'start': 2.535, 'end': 2.655, 'score': 0.818},
+    {'word': 'common', 'start': 2.695, 'end': 3.016, 'score': 0.971}, {'word': 'object,', 'start': 3.196, 'end': 3.577, 'score': 0.823},
+    {'word': 'which', 'start': 3.717, 'end': 3.898, 'score': 0.701}, {'word': 'the', 'start': 3.958, 'end': 4.058, 'score': 0.798},
+    {'word': 'sense', 'start': 4.098, 'end': 4.359, 'score': 0.797}, {'word': 'deceives,', 'start': 4.419, 'end': 4.92, 'score': 0.802},
+    {'word': 'lost', 'start': 5.101, 'end': 5.481, 'score': 0.71}, {'word': 'not', 'start': 5.682, 'end': 5.963, 'score': 0.781},
+    {'word': 'by', 'start': 6.043, 'end': 6.183, 'score': 0.834}, {'word': 'distance', 'start': 6.223, 'end': 6.644, 'score': 0.899},
+    {'word': 'any', 'start': 6.905, 'end': 7.065, 'score': 0.893}, {'word': 'of', 'start': 7.125, 'end': 7.185, 'score': 0.772},
+    {'word': 'its', 'start': 7.245, 'end': 7.346, 'score': 0.778}, {'word': 'marks.', 'start': 7.406, 'end': 7.727, 'score': 0.955}
 ]
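
Both demo lists were regenerated with whisperX timings: the dropdown entries keep the "{start} {word} {end}" format that get_transcribe_state emits as word_bounds, and each words_info entry now carries an alignment score and drops the leading space that plain whisper put in front of every word.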
@@ -342,21 +436,24 @@ with gr.Blocks() as app:
     with gr.Accordion("Select models", open=False) as models_selector:
         with gr.Row():
             voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"])
+            whisper_backend_choice = gr.Radio(label="Whisper backend", value="whisperX", choices=["whisper", "whisperX"])
             whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
-                                            choices=[None, "tiny.en", "base.en", "small.en", "medium.en", "large"])
+                                            choices=[None, "base.en", "small.en", "medium.en", "large"])
+            align_model_choice = gr.Radio(label="Forced alignment model", value="whisperX", choices=[None, "whisperX"])
 
     with gr.Row():
         with gr.Column(scale=2):
             input_audio = gr.Audio(value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath")
             with gr.Group():
-                original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript, interactive=False,
-                                                 info="Use whisper model to get the transcript. Fix it if necessary.")
+                original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript,
+                                                 info="Use whisper model to get the transcript. Fix and align it if necessary.")
 
                 with gr.Accordion("Word start time", open=False):
                     transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word")
                 with gr.Accordion("Word end time", open=False):
                     transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word")
                 transcribe_btn = gr.Button(value="Transcribe")
+                align_btn = gr.Button(value="Align")
 
         with gr.Column(scale=3):
             with gr.Group():
@@ -375,15 +472,15 @@ with gr.Blocks() as app:
 
             with gr.Group() as tts_mode_controls:
                 prompt_to_word = gr.Dropdown(label="Last word in prompt", choices=demo_words, value=demo_words[10], interactive=True)
-                prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.01, value=3.01)
+                prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.001, value=3.016)
 
             with gr.Group(visible=False) as edit_mode_controls:
                 with gr.Row():
                     edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, value=demo_words[2], interactive=True)
                     edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, value=demo_words[12], interactive=True)
                 with gr.Row():
-                    edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.93, step=0.01, value=0.35)
-                    edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.93, step=0.01, value=3.75)
+                    edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.93, step=0.001, value=0.46)
+                    edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.93, step=0.001, value=3.808)
 
             run_btn = gr.Button(value="Run")
@@ -418,7 +515,7 @@ with gr.Blocks() as app:
 
     audio_tensors = gr.State()
-    word_info = gr.State(value=demo_word_info)
+    transcribe_state = gr.State(value={"words_info": demo_words_info})
 
     mode.change(fn=update_demo,
@@ -432,7 +529,7 @@ with gr.Blocks() as app:
                 outputs=[transcript, edit_from_word, edit_to_word])
 
     load_models_btn.click(fn=load_models,
-                          inputs=[whisper_model_choice, voicecraft_model_choice],
+                          inputs=[whisper_backend_choice, whisper_model_choice, align_model_choice, voicecraft_model_choice],
                           outputs=[models_selector])
 
     input_audio.upload(fn=update_input_audio,
@@ -441,7 +538,11 @@ with gr.Blocks() as app:
     transcribe_btn.click(fn=transcribe,
                          inputs=[seed, input_audio],
                          outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time,
-                                  prompt_to_word, edit_from_word, edit_to_word, word_info])
+                                  prompt_to_word, edit_from_word, edit_to_word, transcribe_state])
+    align_btn.click(fn=align,
+                    inputs=[seed, original_transcript, input_audio],
+                    outputs=[transcript_with_start_time, transcript_with_end_time,
+                             prompt_to_word, edit_from_word, edit_to_word, transcribe_state])
 
     mode.change(fn=change_mode,
                 inputs=[mode],
@@ -454,7 +555,7 @@ with gr.Blocks() as app:
                     top_k, top_p, temperature,
                     stop_repetition, sample_batch_size,
                     kvcache, silence_tokens,
-                    input_audio, word_info, transcript, smart_transcript,
+                    input_audio, transcribe_state, transcript, smart_transcript,
                     mode, prompt_end_time, edit_start_time, edit_end_time,
                     split_text, sentence_selector, audio_tensors
                 ],
@@ -470,7 +571,7 @@ with gr.Blocks() as app:
                     top_k, top_p, temperature,
                     stop_repetition, sample_batch_size,
                     kvcache, silence_tokens,
-                    input_audio, word_info, transcript, smart_transcript,
+                    input_audio, transcribe_state, transcript, smart_transcript,
                     gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time,
                     split_text, sentence_selector, audio_tensors
                 ],

View File

@@ -1,3 +1,5 @@
 gradio==3.50.2
 nltk>=3.8.1
-openai-whisper>=20231117
+openai-whisper>=20231117
+aeneas>=1.7.3.0
+whisperx>=3.1.1
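
One packaging caveat beyond what the diff shows: whisperx pulls in faster-whisper and fetches its alignment checkpoints on first use, and aeneas builds against espeak and expects ffmpeg on the system (per the upstream projects' documentation), so verify both before relying on the pinned versions above.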