update with whisperx

parent 6e5382584c
commit 9746a1f60c
@@ -26,4 +26,6 @@ thumbs.db
 src/audiocraft
 !/demo/
 !/demo/*
+
+.cog/tmp/*
cog.yaml (25 changed lines)

@@ -4,30 +4,21 @@
 build:
   gpu: true
   system_packages:
-    - "libgl1-mesa-glx"
-    - "libglib2.0-0"
+    - libgl1-mesa-glx
+    - libglib2.0-0
+    - ffmpeg
+    - espeak-ng
-  python_version: "3.9.16"
+  python_version: "3.11"
   python_packages:
-    - torch==2.0.1
-    - torchaudio==2.0.2
-    - xformers==0.0.22
-    - tensorboard==2.16.2
+    - torch==2.1.0
+    - torchaudio==2.1.0
+    - xformers
     - phonemizer==3.2.1
     - datasets==2.16.0
     - torchmetrics==0.11.1
+    - whisperx==3.1.1
+    - openai-whisper>=20231117
   run:
-    - curl -O https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh
-    - bash Miniconda3-py310_23.3.1-0-Linux-x86_64.sh -b -p /cog/miniconda
-    - /cog/miniconda/bin/conda init bash
-    - /bin/bash -c "source /cog/miniconda/bin/activate && conda create -n myenv python=3.9.16 -y"
-    - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068 -y"
-    - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && mfa model download dictionary english_us_arpa && mfa model download acoustic english_us_arpa"
-    - export PATH=/cog/miniconda/envs/myenv/bin:$PATH
-    - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft
-    - pip install "pydantic<2.0.0"
     - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
-predict: "predict.py:Predictor"
+    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth"
+predict: "predict.py:Predictor"
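Note: the new wget step pre-populates the torch hub cache so that WhisperX's English alignment model (which, to the best of my reading of whisperx==3.1.1, resolves to torchaudio's WAV2VEC2_ASR_BASE_960H pipeline) does not download its checkpoint at prediction time. A minimal sketch of what that cache serves, under that assumption:

# Sketch only; assumes whisperx's "en" align model maps to this torchaudio pipeline,
# whose checkpoint is the wav2vec2_fairseq_base_ls960_asr_ls960.pth file baked into
# /root/.cache/torch/hub/checkpoints/ by the cog.yaml run step above.
import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()   # served from the pre-populated cache, no network needed
print(bundle.sample_rate)    # 16000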
predict.py (133 changed lines)
@@ -2,9 +2,7 @@
 # https://github.com/replicate/cog/blob/main/docs/python.md
 
 import os
 import stat
 import time
 import warnings
 import random
-import getpass
-import shutil
@@ -12,12 +10,10 @@ import subprocess
 import torch
 import numpy as np
 import torchaudio
 
 from whisper.model import Whisper, ModelDimensions
 from whisper.tokenizer import get_tokenizer
 from cog import BasePredictor, Input, Path, BaseModel
 
 warnings.filterwarnings("ignore", category=UserWarning)
-os.environ["USER"] = getpass.getuser()
-
 from data.tokenizer import (
@@ -27,14 +23,12 @@ from data.tokenizer import (
 from models import voicecraft
 from inference_tts_scale import inference_one_sample
 from edit_utils import get_span
-from inference_speech_editing_scale import get_mask_interval
 from inference_speech_editing_scale import (
     inference_one_sample as inference_one_sample_editing,
 )
 
-ENV_NAME = "myenv"
-
-MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft.tar"
+MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft-models.tar"  # all the models are cached and uploaded to replicate.delivery for faster booting
 MODEL_CACHE = "model_cache"
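Note: this hunk does not show how MODEL_URL is fetched into MODEL_CACHE; on Replicate that is typically done with the pget binary installed in cog.yaml. A hedged sketch of that usual pattern (download_weights is hypothetical, not a function taken from this repo):

# Hypothetical helper illustrating the common pget pattern; the actual download code
# in predict.py's setup() is not part of this hunk.
import os
import subprocess
import time

def download_weights(url: str, dest: str) -> None:
    start = time.time()
    if not os.path.exists(dest):
        # pget -x downloads the tar archive and extracts it into dest
        subprocess.check_call(["pget", "-x", url, dest])
    print("downloading took:", time.time() - start, "seconds")

# download_weights(MODEL_URL, MODEL_CACHE)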
@@ -43,9 +37,56 @@ class ModelOutput(BaseModel):
     generated_audio: Path
 
 
+class WhisperxAlignModel:
+    def __init__(self):
+        from whisperx import load_align_model
+
+        self.model, self.metadata = load_align_model(
+            language_code="en", device="cuda:0"
+        )
+
+    def align(self, segments, audio_path):
+        from whisperx import align, load_audio
+
+        audio = load_audio(audio_path)
+        return align(
+            segments,
+            self.model,
+            self.metadata,
+            audio,
+            device="cuda:0",
+            return_char_alignments=False,
+        )["segments"]
+
+
+class WhisperxModel:
+    def __init__(self, model_name, align_model: WhisperxAlignModel, device="cuda"):
+        from whisperx import load_model
+
+        # the model weights are cached from Systran/faster-whisper-base.en etc
+        self.model = load_model(
+            model_name,
+            device,
+            asr_options={
+                "suppress_numerals": True,
+                "max_new_tokens": None,
+                "clip_timestamps": None,
+                "hallucination_silence_threshold": None,
+            },
+        )
+        self.align_model = align_model
+
+    def transcribe(self, audio_path):
+        segments = self.model.transcribe(audio_path, language="en", batch_size=8)[
+            "segments"
+        ]
+        return self.align_model.align(segments, audio_path)
+
+
 class WhisperModel:
     def __init__(self, model_cache, model_name="base.en", device="cuda"):
 
         # the model weights are cached from https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/__init__.py#L17
         with open(f"{model_cache}/{model_name}.pt", "rb") as fp:
             checkpoint = torch.load(fp, map_location="cpu")
         dims = ModelDimensions(**checkpoint["dims"])
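Note: a minimal sketch of how the two new classes chain ASR and forced alignment; the paths are placeholders, but the classes are exactly the ones added above:

# Placeholder paths; WhisperxAlignModel/WhisperxModel are the classes defined in this diff.
align_model = WhisperxAlignModel()                             # wav2vec2-based word aligner
asr = WhisperxModel("model_cache/whisperx_base", align_model)  # faster-whisper backend

# transcribe() runs ASR, then re-aligns the segments so each word carries timestamps
# that downstream code can turn into word_bounds for speech editing.
segments = asr.transcribe("demo/sample.wav")
for segment in segments:
    for word in segment.get("words", []):
        print(word["word"], word.get("start"), word.get("end"))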
@@ -105,11 +146,17 @@ class Predictor(BasePredictor):
 
         self.text_tokenizer = TextTokenizer(backend="espeak")
         self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)
-        self.transcribe_models = {
+        self.transcribe_models_whisper = {
             k: WhisperModel(MODEL_CACHE, k, self.device)
             for k in ["base.en", "small.en", "medium.en"]
         }
 
+        align_model = WhisperxAlignModel()
+        self.transcribe_models_whisperx = {
+            k: WhisperxModel(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}", align_model)
+            for k in ["base.en", "small.en", "medium.en"]
+        }
+
     def predict(
         self,
         task: str = Input(
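Note on the cache layout implied by the f-string above (directory names are derived from the code; whether the weights tarball actually ships them is not shown in this diff):

# For each supported size, the faster-whisper/CTranslate2 weights are expected at
# model_cache/whisperx_base, model_cache/whisperx_small, model_cache/whisperx_medium.
for k in ["base.en", "small.en", "medium.en"]:
    print(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}")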
@@ -135,9 +182,16 @@ class Predictor(BasePredictor):
             default="",
         ),
         whisper_model: str = Input(
-            description="If orig_transcript is not provided above, choose a Whisper model. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ",
-            choices=["base.en", "small.en", "medium.en"],
-            default="base.en",
+            description="If orig_transcript is not provided above, choose a Whisper or WhisperX model. WhisperX model contains extra alignment steps. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ",
+            choices=[
+                "whisper-base.en",
+                "whisper-small.en",
+                "whisper-medium.en",
+                "whisperx-base.en",
+                "whisperx-small.en",
+                "whisperx-medium.en",
+            ],
+            default="whisper-base.en",
         ),
         target_transcript: str = Input(
             description="Transcript of the target audio file",
@@ -188,8 +242,19 @@ class Predictor(BasePredictor):
 
         seed_everything(seed)
 
-        segments = self.transcribe_models[whisper_model].transcribe(str(orig_audio))
+        whisper_model, whisper_model_size = whisper_model.split("-")
+
+        if whisper_model == "whisper":
+            segments = self.transcribe_models_whisper[whisper_model_size].transcribe(
+                str(orig_audio)
+            )
+        else:
+            segments = self.transcribe_models_whisperx[whisper_model_size].transcribe(
+                str(orig_audio)
+            )
+
         state = get_transcribe_state(segments)
 
         whisper_transcript = state["transcript"].strip()
 
         if len(orig_transcript.strip()) == 0:
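Note: the new choice strings encode both backend and size, so a single split("-") recovers them; the ".en" suffix stays attached to the size key used to index the model dicts. Illustration only:

backend, size = "whisperx-base.en".split("-")
print(backend, size)  # whisperx base.en  -> key into self.transcribe_models_whisperx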
@@ -204,25 +269,8 @@
         os.makedirs(temp_folder)
 
-        filename = "orig_audio"
-        shutil.copy(orig_audio, f"{temp_folder}/{filename}.wav")
+        audio_fn = str(orig_audio)
 
-        with open(f"{temp_folder}/{filename}.txt", "w") as f:
-            f.write(orig_transcript)
-
-        # run MFA to get the alignment
-        align_temp = f"{temp_folder}/mfa_alignments"
-
-        command = f'/bin/bash -c "source /cog/miniconda/bin/activate && conda activate {ENV_NAME} && mfa align -v --clean -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}"'
-        try:
-            subprocess.run(command, shell=True, check=True)
-        except subprocess.CalledProcessError as e:
-            print("Error:", e)
-            raise RuntimeError("Error running Alignment")
-
-        print("Alignment done!")
-
-        align_fn = f"{align_temp}/{filename}.csv"
-        audio_fn = f"{temp_folder}/{filename}.wav"
         info = torchaudio.info(audio_fn)
         audio_dur = info.num_frames / info.sample_rate
@@ -262,7 +310,10 @@
             new_span_save = new_span
         orig_span_save = ",".join([str(item) for item in orig_span_save])
         new_span_save = ",".join([str(item) for item in new_span_save])
-        start, end = get_mask_interval(align_fn, orig_span_save, edit_type)
+
+        start, end = get_mask_interval_from_word_bounds(
+            state["word_bounds"], orig_span_save, edit_type
+        )
 
         # span in codec frames
         morphed_span = (
@@ -359,3 +410,23 @@ def find_closest_cut_off_word(word_bounds, cut_off_sec):
             break
 
     return i
+
+
+def get_mask_interval_from_word_bounds(word_bounds, word_span_ind, editType):
+    tmp = word_span_ind.split(",")
+    s, e = int(tmp[0]), int(tmp[-1])
+    start = None
+    for j, item in enumerate(word_bounds):
+        if j == s:
+            if editType == "insertion":
+                start = float(item["end"])
+            else:
+                start = float(item["start"])
+        if j == e:
+            if editType == "insertion":
+                end = float(item["start"])
+            else:
+                end = float(item["end"])
+            assert start != None
+            break
+    return (start, end)
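Note: a quick worked example of get_mask_interval_from_word_bounds as added above. The word_bounds entries are hypothetical, but they use the "start"/"end" keys the function reads:

# Hypothetical word-level bounds (in seconds), e.g. produced from WhisperX alignment.
word_bounds = [
    {"word": "the", "start": 0.00, "end": 0.20},
    {"word": "quick", "start": 0.25, "end": 0.60},
    {"word": "brown", "start": 0.65, "end": 1.00},
    {"word": "fox", "start": 1.05, "end": 1.40},
]

# Substituting words 1..2 masks from the start of "quick" to the end of "brown".
print(get_mask_interval_from_word_bounds(word_bounds, "1,2", "substitution"))  # (0.25, 1.0)

# Inserting between words 1 and 2 masks only the gap between them.
print(get_mask_interval_from_word_bounds(word_bounds, "1,2", "insertion"))  # (0.6, 0.65)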