diff --git a/.gitignore b/.gitignore
index 17dbc9b..90a560c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,4 +26,6 @@
 thumbs.db
 src/audiocraft
 !/demo/
-!/demo/*
\ No newline at end of file
+!/demo/*
+
+.cog/tmp/*
\ No newline at end of file
diff --git a/cog.yaml b/cog.yaml
index 5204813..a020931 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -4,30 +4,21 @@
 build:
   gpu: true
   system_packages:
-    - "libgl1-mesa-glx"
-    - "libglib2.0-0"
+    - libgl1-mesa-glx
+    - libglib2.0-0
     - ffmpeg
     - espeak-ng
-  python_version: "3.9.16"
+  python_version: "3.11"
   python_packages:
-    - torch==2.0.1
-    - torchaudio==2.0.2
-    - xformers==0.0.22
-    - tensorboard==2.16.2
+    - torch==2.1.0
+    - torchaudio==2.1.0
+    - xformers
     - phonemizer==3.2.1
-    - datasets==2.16.0
-    - torchmetrics==0.11.1
     - whisperx==3.1.1
     - openai-whisper>=20231117
   run:
-    - curl -O https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh
-    - bash Miniconda3-py310_23.3.1-0-Linux-x86_64.sh -b -p /cog/miniconda
-    - /cog/miniconda/bin/conda init bash
-    - /bin/bash -c "source /cog/miniconda/bin/activate && conda create -n myenv python=3.9.16 -y"
-    - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068 -y"
-    - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && mfa model download dictionary english_us_arpa && mfa model download acoustic english_us_arpa"
-    - export PATH=/cog/miniconda/envs/myenv/bin:$PATH
     - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft
     - pip install "pydantic<2.0.0"
     - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
-predict: "predict.py:Predictor"
\ No newline at end of file
+    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth"
+predict: "predict.py:Predictor"
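
Note: the new wget step in cog.yaml pre-seeds torch hub's checkpoint cache so that WhisperX's English wav2vec2 alignment model is available at boot instead of being downloaded on first use. A minimal sketch of the lookup it satisfies, assuming the default cache location (torch.hub.get_dir() is the standard cache root; the filename is the one hard-coded above):

import os
import torch

# torch hub resolves downloaded weights under <hub_dir>/checkpoints/;
# hub_dir defaults to ~/.cache/torch/hub, i.e. /root/.cache/torch/hub in this image
ckpt = os.path.join(
    torch.hub.get_dir(),
    "checkpoints",
    "wav2vec2_fairseq_base_ls960_asr_ls960.pth",
)
print(os.path.exists(ckpt))  # True once the cog.yaml run step has executed
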
diff --git a/predict.py b/predict.py
index 001fb8a..4258076 100644
--- a/predict.py
+++ b/predict.py
@@ -2,9 +2,7 @@
 # https://github.com/replicate/cog/blob/main/docs/python.md
 
 import os
-import stat
 import time
-import warnings
 import random
 import getpass
 import shutil
@@ -12,12 +10,10 @@ import subprocess
 import torch
 import numpy as np
 import torchaudio
-
 from whisper.model import Whisper, ModelDimensions
 from whisper.tokenizer import get_tokenizer
 from cog import BasePredictor, Input, Path, BaseModel
 
-warnings.filterwarnings("ignore", category=UserWarning)
 os.environ["USER"] = getpass.getuser()
 
 from data.tokenizer import (
@@ -27,14 +23,12 @@ from data.tokenizer import (
 from models import voicecraft
 from inference_tts_scale import inference_one_sample
 from edit_utils import get_span
-from inference_speech_editing_scale import get_mask_interval
 from inference_speech_editing_scale import (
     inference_one_sample as inference_one_sample_editing,
 )
 
-ENV_NAME = "myenv"
-
-MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft.tar"
+MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft-models.tar"  # all the models are cached and uploaded to replicate.delivery for faster booting
 MODEL_CACHE = "model_cache"
@@ -43,9 +37,56 @@
 class ModelOutput(BaseModel):
     generated_audio: Path
 
 
+class WhisperxAlignModel:
+    def __init__(self):
+        from whisperx import load_align_model
+
+        self.model, self.metadata = load_align_model(
+            language_code="en", device="cuda:0"
+        )
+
+    def align(self, segments, audio_path):
+        from whisperx import align, load_audio
+
+        audio = load_audio(audio_path)
+        return align(
+            segments,
+            self.model,
+            self.metadata,
+            audio,
+            device="cuda:0",
+            return_char_alignments=False,
+        )["segments"]
+
+
+class WhisperxModel:
+    def __init__(self, model_name, align_model: WhisperxAlignModel, device="cuda"):
+        from whisperx import load_model
+
+        # the model weights are cached from Systran/faster-whisper-base.en etc
+        self.model = load_model(
+            model_name,
+            device,
+            asr_options={
+                "suppress_numerals": True,
+                "max_new_tokens": None,
+                "clip_timestamps": None,
+                "hallucination_silence_threshold": None,
+            },
+        )
+        self.align_model = align_model
+
+    def transcribe(self, audio_path):
+        segments = self.model.transcribe(audio_path, language="en", batch_size=8)[
+            "segments"
+        ]
+        return self.align_model.align(segments, audio_path)
+
+
 class WhisperModel:
     def __init__(self, model_cache, model_name="base.en", device="cuda"):
+        # the model weights are cached from https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/__init__.py#L17
         with open(f"{model_cache}/{model_name}.pt", "rb") as fp:
             checkpoint = torch.load(fp, map_location="cpu")
         dims = ModelDimensions(**checkpoint["dims"])
@@ -105,11 +146,17 @@ class Predictor(BasePredictor):
         self.text_tokenizer = TextTokenizer(backend="espeak")
         self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)
 
-        self.transcribe_models = {
+        self.transcribe_models_whisper = {
             k: WhisperModel(MODEL_CACHE, k, self.device)
             for k in ["base.en", "small.en", "medium.en"]
         }
 
+        align_model = WhisperxAlignModel()
+        self.transcribe_models_whisperx = {
+            k: WhisperxModel(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}", align_model)
+            for k in ["base.en", "small.en", "medium.en"]
+        }
+
     def predict(
         self,
         task: str = Input(
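
Note: a hedged sketch of how the two WhisperX wrappers above chain ASR and alignment; the model directory and audio file below are hypothetical, and a CUDA device plus the whisperx pin from cog.yaml are assumed:

# hypothetical standalone usage of WhisperxAlignModel / WhisperxModel from this diff
align_model = WhisperxAlignModel()  # wav2vec2 aligner, pre-cached by the cog.yaml wget step
asr = WhisperxModel("model_cache/whisperx_base", align_model)

segments = asr.transcribe("orig_audio.wav")
# aligned whisperx segments carry per-word timing, roughly:
# segments[0]["words"] -> [{"word": "Hello,", "start": 0.21, "end": 0.38, ...}, ...]
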
@@ -135,9 +182,16 @@
             default="",
         ),
         whisper_model: str = Input(
-            description="If orig_transcript is not provided above, choose a Whisper model. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ",
-            choices=["base.en", "small.en", "medium.en"],
-            default="base.en",
+            description="If orig_transcript is not provided above, choose a Whisper or WhisperX model; the WhisperX models include an extra alignment step. An inaccurate transcription may lead to errors in TTS or speech editing. You can modify the generated transcript and provide it directly to orig_transcript above.",
+            choices=[
+                "whisper-base.en",
+                "whisper-small.en",
+                "whisper-medium.en",
+                "whisperx-base.en",
+                "whisperx-small.en",
+                "whisperx-medium.en",
+            ],
+            default="whisper-base.en",
         ),
         target_transcript: str = Input(
             description="Transcript of the target audio file",
@@ -188,8 +242,19 @@
 
         seed_everything(seed)
 
-        segments = self.transcribe_models[whisper_model].transcribe(str(orig_audio))
+        whisper_model, whisper_model_size = whisper_model.split("-")
+
+        if whisper_model == "whisper":
+            segments = self.transcribe_models_whisper[whisper_model_size].transcribe(
+                str(orig_audio)
+            )
+        else:
+            segments = self.transcribe_models_whisperx[whisper_model_size].transcribe(
+                str(orig_audio)
+            )
+
         state = get_transcribe_state(segments)
+
         whisper_transcript = state["transcript"].strip()
 
         if len(orig_transcript.strip()) == 0:
@@ -204,25 +269,8 @@
         os.makedirs(temp_folder)
 
         filename = "orig_audio"
-        shutil.copy(orig_audio, f"{temp_folder}/{filename}.wav")
+        audio_fn = str(orig_audio)
 
-        with open(f"{temp_folder}/{filename}.txt", "w") as f:
-            f.write(orig_transcript)
-
-        # run MFA to get the alignment
-        align_temp = f"{temp_folder}/mfa_alignments"
-
-        command = f'/bin/bash -c "source /cog/miniconda/bin/activate && conda activate {ENV_NAME} && mfa align -v --clean -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}"'
-        try:
-            subprocess.run(command, shell=True, check=True)
-        except subprocess.CalledProcessError as e:
-            print("Error:", e)
-            raise RuntimeError("Error running Alignment")
-
-        print("Alignment done!")
-
-        align_fn = f"{align_temp}/{filename}.csv"
-        audio_fn = f"{temp_folder}/{filename}.wav"
         info = torchaudio.info(audio_fn)
         audio_dur = info.num_frames / info.sample_rate
@@ -262,7 +310,10 @@
             new_span_save = new_span
         orig_span_save = ",".join([str(item) for item in orig_span_save])
         new_span_save = ",".join([str(item) for item in new_span_save])
-        start, end = get_mask_interval(align_fn, orig_span_save, edit_type)
+
+        start, end = get_mask_interval_from_word_bounds(
+            state["word_bounds"], orig_span_save, edit_type
+        )
 
         # span in codec frames
         morphed_span = (
@@ -359,3 +410,23 @@
             break
 
     return i
+
+
+def get_mask_interval_from_word_bounds(word_bounds, word_span_ind, editType):
+    tmp = word_span_ind.split(",")
+    s, e = int(tmp[0]), int(tmp[-1])
+    start = None
+    for j, item in enumerate(word_bounds):
+        if j == s:
+            if editType == "insertion":
+                start = float(item["end"])
+            else:
+                start = float(item["start"])
+        if j == e:
+            if editType == "insertion":
+                end = float(item["start"])
+            else:
+                end = float(item["end"])
+            assert start is not None
+            break
+    return (start, end)
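
Note: get_mask_interval_from_word_bounds replaces the MFA-CSV-based get_mask_interval, reading word timings from the WhisperX alignment state instead of an alignment file on disk; this is what allows the whole miniconda/MFA toolchain to be dropped from cog.yaml. A quick worked example of its two cases (the word bounds are made-up values for illustration):

word_bounds = [
    {"word": "hello", "start": 0.0, "end": 0.4},
    {"word": "brave", "start": 0.5, "end": 0.9},
    {"word": "world", "start": 1.0, "end": 1.4},
]

# substitution/deletion: the mask spans the edited words themselves
assert get_mask_interval_from_word_bounds(word_bounds, "1,1", "substitution") == (0.5, 0.9)

# insertion: the mask is only the gap between word 0's end and word 1's start
assert get_mask_interval_from_word_bounds(word_bounds, "0,1", "insertion") == (0.4, 0.5)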