update with whisperx

chenxwh 2024-04-19 10:46:00 +00:00
parent 6e5382584c
commit 9746a1f60c
3 changed files with 113 additions and 49 deletions

.gitignore

@@ -26,4 +26,6 @@ thumbs.db
src/audiocraft
!/demo/
!/demo/*
.cog/tmp/*

cog.yaml

@@ -4,30 +4,21 @@
build:
  gpu: true
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
    - libgl1-mesa-glx
    - libglib2.0-0
    - ffmpeg
    - espeak-ng
  python_version: "3.9.16"
  python_version: "3.11"
  python_packages:
    - torch==2.0.1
    - torchaudio==2.0.2
    - xformers==0.0.22
    - tensorboard==2.16.2
    - torch==2.1.0
    - torchaudio==2.1.0
    - xformers
    - phonemizer==3.2.1
    - datasets==2.16.0
    - torchmetrics==0.11.1
    - whisperx==3.1.1
    - openai-whisper>=20231117
  run:
    - curl -O https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh
    - bash Miniconda3-py310_23.3.1-0-Linux-x86_64.sh -b -p /cog/miniconda
    - /cog/miniconda/bin/conda init bash
    - /bin/bash -c "source /cog/miniconda/bin/activate && conda create -n myenv python=3.9.16 -y"
    - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068 -y"
    - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && mfa model download dictionary english_us_arpa && mfa model download acoustic english_us_arpa"
    - export PATH=/cog/miniconda/envs/myenv/bin:$PATH
    - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft
    - pip install "pydantic<2.0.0"
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth"
predict: "predict.py:Predictor"
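The conda/MFA toolchain is dropped because whisperx now handles both transcription and word-level alignment in-process (the wav2vec2 checkpoint fetched above backs the aligner). A minimal sketch of that flow, using only the whisperx calls this commit itself imports; the audio path is hypothetical:

import whisperx

device = "cuda"
audio = whisperx.load_audio("orig_audio.wav")  # hypothetical input file
model = whisperx.load_model("base.en", device)
segments = model.transcribe(audio, batch_size=8)["segments"]

# word-level alignment, replacing the removed `mfa align` subprocess
align_model, metadata = whisperx.load_align_model(language_code="en", device=device)
aligned = whisperx.align(segments, align_model, metadata, audio, device)["segments"]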

predict.py

@@ -2,9 +2,7 @@
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import stat
import time
import warnings
import random
import getpass
import shutil
@@ -12,12 +10,10 @@ import subprocess
import torch
import numpy as np
import torchaudio
from whisper.model import Whisper, ModelDimensions
from whisper.tokenizer import get_tokenizer
from cog import BasePredictor, Input, Path, BaseModel
warnings.filterwarnings("ignore", category=UserWarning)
os.environ["USER"] = getpass.getuser()
from data.tokenizer import (
@@ -27,14 +23,12 @@ from data.tokenizer import (
from models import voicecraft
from inference_tts_scale import inference_one_sample
from edit_utils import get_span
from inference_speech_editing_scale import get_mask_interval
from inference_speech_editing_scale import (
    inference_one_sample as inference_one_sample_editing,
)
ENV_NAME = "myenv"
MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft.tar"
MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft-models.tar" # all the models are cached and uploaded to replicate.delivery for faster booting
MODEL_CACHE = "model_cache"
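setup() itself is outside this hunk, but on Replicate the usual pattern for a MODEL_URL tarball is to fetch and unpack it at boot with the pget binary installed in cog.yaml; a sketch under that assumption (the helper name is hypothetical):

import os
import subprocess

def download_weights(url: str, dest: str) -> None:
    # pget -x downloads and extracts a tar archive in one step
    subprocess.check_call(["pget", "-x", url, dest])

# assumed usage at the start of setup():
#     if not os.path.exists(MODEL_CACHE):
#         download_weights(MODEL_URL, MODEL_CACHE)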
@@ -43,9 +37,56 @@ class ModelOutput(BaseModel):
    generated_audio: Path


class WhisperxAlignModel:
    def __init__(self):
        from whisperx import load_align_model

        self.model, self.metadata = load_align_model(
            language_code="en", device="cuda:0"
        )

    def align(self, segments, audio_path):
        from whisperx import align, load_audio

        audio = load_audio(audio_path)
        return align(
            segments,
            self.model,
            self.metadata,
            audio,
            device="cuda:0",
            return_char_alignments=False,
        )["segments"]
class WhisperxModel:
    def __init__(self, model_name, align_model: WhisperxAlignModel, device="cuda"):
        from whisperx import load_model

        # the model weights are cached from Systran/faster-whisper-base.en etc.
        self.model = load_model(
            model_name,
            device,
            asr_options={
                "suppress_numerals": True,
                "max_new_tokens": None,
                "clip_timestamps": None,
                "hallucination_silence_threshold": None,
            },
        )
        self.align_model = align_model

    def transcribe(self, audio_path):
        segments = self.model.transcribe(audio_path, language="en", batch_size=8)[
            "segments"
        ]
        return self.align_model.align(segments, audio_path)
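Hypothetical usage of the two wrappers together, with the cache-path format that setup() below constructs:

align_model = WhisperxAlignModel()
asr = WhisperxModel("model_cache/whisperx_base", align_model)
segments = asr.transcribe("orig_audio.wav")  # transcribed, then word-aligned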
class WhisperModel:
    def __init__(self, model_cache, model_name="base.en", device="cuda"):
        # the model weights are cached from https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/__init__.py#L17
        with open(f"{model_cache}/{model_name}.pt", "rb") as fp:
            checkpoint = torch.load(fp, map_location="cpu")
        dims = ModelDimensions(**checkpoint["dims"])
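The hunk is truncated here; judging from the openai-whisper loader that the comment above links to, the constructor would continue along these lines (a sketch, not necessarily the commit's exact code):

        model = Whisper(dims)
        model.load_state_dict(checkpoint["model_state_dict"])
        self.model = model.to(device)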
@@ -105,11 +146,17 @@ class Predictor(BasePredictor):
        self.text_tokenizer = TextTokenizer(backend="espeak")
        self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)
        self.transcribe_models = {
        self.transcribe_models_whisper = {
            k: WhisperModel(MODEL_CACHE, k, self.device)
            for k in ["base.en", "small.en", "medium.en"]
        }

        align_model = WhisperxAlignModel()
        self.transcribe_models_whisperx = {
            k: WhisperxModel(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}", align_model)
            for k in ["base.en", "small.en", "medium.en"]
        }
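The two dicts are keyed by model size; predict() below selects between them by splitting its whisper_model choice string on the hyphen. A concrete illustration:

# e.g. the choice string "whisperx-base.en" splits into engine and size
engine, size = "whisperx-base.en".split("-")  # ("whisperx", "base.en")
# engine chooses transcribe_models_whisper vs transcribe_models_whisperx,
# and size indexes the chosen dict (see the dispatch in predict below)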
    def predict(
        self,
        task: str = Input(
@@ -135,9 +182,16 @@ class Predictor(BasePredictor):
default="",
),
whisper_model: str = Input(
description="If orig_transcript is not provided above, choose a Whisper model. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ",
choices=["base.en", "small.en", "medium.en"],
default="base.en",
description="If orig_transcript is not provided above, choose a Whisper or WhisperX model. WhisperX model contains extra alignment steps. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ",
choices=[
"whisper-base.en",
"whisper-small.en",
"whisper-medium.en",
"whisperx-base.en",
"whisperx-small.en",
"whisperx-medium.en",
],
default="whisper-base.en",
),
        target_transcript: str = Input(
            description="Transcript of the target audio file",
@@ -188,8 +242,19 @@ class Predictor(BasePredictor):
        seed_everything(seed)

        segments = self.transcribe_models[whisper_model].transcribe(str(orig_audio))
        whisper_model, whisper_model_size = whisper_model.split("-")
        if whisper_model == "whisper":
            segments = self.transcribe_models_whisper[whisper_model_size].transcribe(
                str(orig_audio)
            )
        else:
            segments = self.transcribe_models_whisperx[whisper_model_size].transcribe(
                str(orig_audio)
            )

        state = get_transcribe_state(segments)
        whisper_transcript = state["transcript"].strip()

        if len(orig_transcript.strip()) == 0:
@@ -204,25 +269,8 @@
        os.makedirs(temp_folder)

        filename = "orig_audio"
        shutil.copy(orig_audio, f"{temp_folder}/{filename}.wav")
        audio_fn = str(orig_audio)

        with open(f"{temp_folder}/{filename}.txt", "w") as f:
            f.write(orig_transcript)

        # run MFA to get the alignment
        align_temp = f"{temp_folder}/mfa_alignments"
        command = f'/bin/bash -c "source /cog/miniconda/bin/activate && conda activate {ENV_NAME} && mfa align -v --clean -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}"'
        try:
            subprocess.run(command, shell=True, check=True)
        except subprocess.CalledProcessError as e:
            print("Error:", e)
            raise RuntimeError("Error running Alignment")
        print("Alignment done!")
        align_fn = f"{align_temp}/{filename}.csv"
        audio_fn = f"{temp_folder}/{filename}.wav"

        info = torchaudio.info(audio_fn)
        audio_dur = info.num_frames / info.sample_rate
@@ -262,7 +310,10 @@ class Predictor(BasePredictor):
            new_span_save = new_span
            orig_span_save = ",".join([str(item) for item in orig_span_save])
            new_span_save = ",".join([str(item) for item in new_span_save])
            start, end = get_mask_interval(align_fn, orig_span_save, edit_type)
            start, end = get_mask_interval_from_word_bounds(
                state["word_bounds"], orig_span_save, edit_type
            )

            # span in codec frames
            morphed_span = (
@@ -359,3 +410,23 @@ def find_closest_cut_off_word(word_bounds, cut_off_sec):
            break
    return i
def get_mask_interval_from_word_bounds(word_bounds, word_span_ind, editType):
    tmp = word_span_ind.split(",")
    s, e = int(tmp[0]), int(tmp[-1])
    start = None
    for j, item in enumerate(word_bounds):
        if j == s:
            if editType == "insertion":
                start = float(item["end"])
            else:
                start = float(item["start"])
        if j == e:
            if editType == "insertion":
                end = float(item["start"])
            else:
                end = float(item["end"])
            assert start is not None
            break
    return (start, end)
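A worked example of this helper with hypothetical word bounds (times in seconds):

word_bounds = [
    {"word": "but", "start": 0.03, "end": 0.18},
    {"word": "when", "start": 0.20, "end": 0.41},
    {"word": "I", "start": 0.45, "end": 0.52},
]
# editing words 1..2 masks from the start of "when" to the end of "I"
get_mask_interval_from_word_bounds(word_bounds, "1,2", "substitution")  # (0.20, 0.52)
# an insertion between them masks only the gap separating the two words
get_mask_interval_from_word_bounds(word_bounds, "1,2", "insertion")  # (0.41, 0.45)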