From 49a648fa54b1802b647dde8ba9244b39922ba893 Mon Sep 17 00:00:00 2001 From: chenxwh Date: Fri, 5 Apr 2024 15:20:11 +0000 Subject: [PATCH 1/8] Replicate TTS v1 demo --- README.md | 4 +- cog.yaml | 32 ++++++++++ predict.py | 184 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 cog.yaml create mode 100644 predict.py diff --git a/README.md b/README.md index d15654d..9e3af25 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild -[Demo](https://jasonppy.github.io/VoiceCraft_web) [Paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf) +[Demo](https://jasonppy.github.io/VoiceCraft_web) [Paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf) +[![Replicate](https://replicate.com/cjwbw/voicecraft/badge)](https://replicate.com/cjwbw/voicecraft) + ### TL;DR diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000..12328c6 --- /dev/null +++ b/cog.yaml @@ -0,0 +1,32 @@ +# Configuration for Cog ⚙️ +# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md + +build: + gpu: true + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + - ffmpeg + - espeak-ng + # - cmake + python_version: "3.9.16" + python_packages: + - torch==2.0.1 + - torchaudio==2.0.2 + - xformers==0.0.22 + - tensorboard==2.16.2 + - phonemizer==3.2.1 + - datasets==2.16.0 + - torchmetrics==0.11.1 + run: + - curl -O https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh + - bash Miniconda3-py310_23.3.1-0-Linux-x86_64.sh -b -p /cog/miniconda + - /cog/miniconda/bin/conda init bash + - /bin/bash -c "source /cog/miniconda/bin/activate && conda create -n myenv python=3.9.16 -y" + - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068 -y" + - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && mfa model download dictionary english_us_arpa && mfa model download acoustic english_us_arpa" + - export PATH=/cog/miniconda/envs/myenv/bin:$PATH + - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft + - pip install "pydantic<2.0.0" + - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget +predict: "predict.py:Predictor" \ No newline at end of file diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..d7a4260 --- /dev/null +++ b/predict.py @@ -0,0 +1,184 @@ +# Prediction interface for Cog ⚙️ +# https://github.com/replicate/cog/blob/main/docs/python.md + +import os +import time +import numpy as np +import random +import getpass +import torch +import torchaudio +import shutil +import subprocess +import sys + +os.environ["USER"] = getpass.getuser() + +from data.tokenizer import ( + AudioTokenizer, + TextTokenizer, +) +from cog import BasePredictor, Input, Path +from models import voicecraft +from inference_tts_scale import inference_one_sample + +ENV_NAME = "myenv" +# sys.path.append(f"/cog/miniconda/envs/{ENV_NAME}/lib/python3.10/site-packages") + +MODEL_URL = "https://weights.replicate.delivery/default/VoiceCraft.tar" +MODEL_CACHE = "model_cache" + + +def download_weights(url, dest): + start = time.time() + print("downloading url: ", url) + print("downloading to: ", dest) + subprocess.check_call(["pget", "-x", url, dest], close_fds=False) + print("downloading took: ", time.time() - start) + + +class Predictor(BasePredictor): + def setup(self): + """Load the model into memory to make running multiple predictions efficient""" + self.device = "cuda" + + voicecraft_name = "giga830M.pth" # or giga330M.pth + + if not os.path.exists(MODEL_CACHE): + download_weights(MODEL_URL, MODEL_CACHE) + + encodec_fn = f"{MODEL_CACHE}/encodec_4cb2048_giga.th" + ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}" + + self.ckpt = torch.load(ckpt_fn, map_location="cpu") + self.model = voicecraft.VoiceCraft(self.ckpt["config"]) + self.model.load_state_dict(self.ckpt["model"]) + self.model.to(self.device) + self.model.eval() + + self.phn2num = self.ckpt["phn2num"] + + self.text_tokenizer = TextTokenizer(backend="espeak") + self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device) + + def predict( + self, + orig_audio: Path = Input(description="Original audio file"), + orig_transcript: str = Input( + description="Transcript of the original audio file. You can use models such as https://replicate.com/openai/whisper and https://replicate.com/vaibhavs10/incredibly-fast-whisper to get the transcript (and modify it if it's not accurate)", + ), + cut_off_sec: float = Input( + description="The first seconds of the original audio that are used for zero-shot text-to-speech (TTS). 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec", + default=3.01, + ), + orig_transcript_until_cutoff_time: str = Input( + description="Transcript of the original audio file until the cut_off_sec specified above. This process will be improved and made automatically later", + ), + target_transcript: str = Input( + description="Transcript of the target audio file", + ), + temperature: float = Input( + description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic, + ge=0.01, + le=5, + default=1, + ), + top_p: float = Input( + description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", + ge=0.0, + le=1.0, + default=0.8, + ), + sampling_rate: int = Input( + description="Specify the sampling rate of the audio codec", default=16000 + ), + seed: int = Input( + description="Random seed. Leave blank to randomize the seed", default=None + ), + ) -> Path: + """Run a single prediction on the model""" + if seed is None: + seed = int.from_bytes(os.urandom(2), "big") + print(f"Using seed: {seed}") + + seed_everything(seed) + + temp_folder = "exp_temp" + if os.path.exists(temp_folder): + shutil.rmtree(temp_folder) + + os.makedirs(temp_folder) + os.system(f"cp {str(orig_audio)} {temp_folder}") + # filename = os.path.splitext(orig_audio.split("/")[-1])[0] + with open(f"{temp_folder}/orig_audio_file.txt", "w") as f: + f.write(orig_transcript) + + # run MFA to get the alignment + align_temp = f"{temp_folder}/mfa_alignments" + + command = f'/bin/bash -c "source /cog/miniconda/bin/activate && conda activate {ENV_NAME} && mfa align -v --clean -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}"' + try: + subprocess.run(command, shell=True, check=True) + except subprocess.CalledProcessError as e: + print("Error:", e) + print("Alignment done!") + + audio_fn = str(orig_audio) # f"{temp_folder}/{filename}.wav" + info = torchaudio.info(audio_fn) + audio_dur = info.num_frames / info.sample_rate + + assert ( + cut_off_sec < audio_dur + ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}" + prompt_end_frame = int(cut_off_sec * info.sample_rate) + + codec_sr = 50 + top_k = 0 + silence_tokens = [1388, 1898, 131] + kvcache = 1 # NOTE if OOM, change this to 0, or try the 330M model + + # NOTE adjust the below three arguments if the generation is not as good + stop_repetition = 3 # NOTE if the model generate long silence, reduce the stop_repetition to 3, 2 or even 1 + sample_batch_size = 4 # NOTE: if the if there are long silence or unnaturally strecthed words, increase sample_batch_size to 5 or higher. What this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. So if the speech rate of the generated is too fast change it to a smaller number. + + decode_config = { + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "stop_repetition": stop_repetition, + "kvcache": kvcache, + "codec_audio_sr": sampling_rate, + "codec_sr": codec_sr, + "silence_tokens": silence_tokens, + "sample_batch_size": sample_batch_size, + } + concated_audio, gen_audio = inference_one_sample( + self.model, + self.ckpt["config"], + self.phn2num, + self.text_tokenizer, + self.audio_tokenizer, + audio_fn, + orig_transcript_until_cutoff_time.strip() + "" + target_transcript.strip(), + self.device, + decode_config, + prompt_end_frame, + ) + + # save segments for comparison + concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu() + + out = "/tmp/out.wav" + torchaudio.save(out, gen_audio, sampling_rate) + torchaudio.save("out.wav", gen_audio, sampling_rate) + return Path(out) + + +def seed_everything(seed): + os.environ["PYTHONHASHSEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True From 023d4b1c6c04786f39a2977c635338be33be7924 Mon Sep 17 00:00:00 2001 From: chenxwh Date: Fri, 5 Apr 2024 17:23:39 +0000 Subject: [PATCH 2/8] replicate demo --- cog.yaml | 1 - predict.py | 156 ++++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 119 insertions(+), 38 deletions(-) diff --git a/cog.yaml b/cog.yaml index 12328c6..a771ad5 100644 --- a/cog.yaml +++ b/cog.yaml @@ -8,7 +8,6 @@ build: - "libglib2.0-0" - ffmpeg - espeak-ng - # - cmake python_version: "3.9.16" python_packages: - torch==2.0.1 diff --git a/predict.py b/predict.py index d7a4260..d8a2890 100644 --- a/predict.py +++ b/predict.py @@ -11,7 +11,9 @@ import torchaudio import shutil import subprocess import sys +import warnings +warnings.filterwarnings("ignore", category=UserWarning) os.environ["USER"] = getpass.getuser() from data.tokenizer import ( @@ -21,9 +23,14 @@ from data.tokenizer import ( from cog import BasePredictor, Input, Path from models import voicecraft from inference_tts_scale import inference_one_sample +from edit_utils import get_span +from inference_speech_editing_scale import get_mask_interval +from inference_speech_editing_scale import ( + inference_one_sample as inference_one_sample_editing, +) ENV_NAME = "myenv" -# sys.path.append(f"/cog/miniconda/envs/{ENV_NAME}/lib/python3.10/site-packages") + MODEL_URL = "https://weights.replicate.delivery/default/VoiceCraft.tar" MODEL_CACHE = "model_cache" @@ -63,22 +70,33 @@ class Predictor(BasePredictor): def predict( self, + task: str = Input( + description="Choose a task. For zero-shot text-to-speech, you also need to specify the cut_off_sec of the original audio to be used for zero-shot generation and the transcript until the cut_off_sec", + choices=[ + "speech_editing-substitution", + "speech_editing-insertion", + "speech_editing-sdeletion", + "zero-shot text-to-speech", + ], + default="speech_editing-substitution", + ), orig_audio: Path = Input(description="Original audio file"), orig_transcript: str = Input( description="Transcript of the original audio file. You can use models such as https://replicate.com/openai/whisper and https://replicate.com/vaibhavs10/incredibly-fast-whisper to get the transcript (and modify it if it's not accurate)", ), - cut_off_sec: float = Input( - description="The first seconds of the original audio that are used for zero-shot text-to-speech (TTS). 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec", - default=3.01, - ), - orig_transcript_until_cutoff_time: str = Input( - description="Transcript of the original audio file until the cut_off_sec specified above. This process will be improved and made automatically later", - ), target_transcript: str = Input( description="Transcript of the target audio file", ), + cut_off_sec: float = Input( + description="Valid/Required for zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech (TTS). 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec", + default=3.01, + ), + orig_transcript_until_cutoff_time: str = Input( + description="Valid/Required for zero-shot text-to-speech task. Transcript of the original audio file until the cut_off_sec specified above. This process will be improved and made automatically later", + default=None, + ), temperature: float = Input( - description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic, + description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", ge=0.01, le=5, default=1, @@ -89,6 +107,10 @@ class Predictor(BasePredictor): le=1.0, default=0.8, ), + stop_repetition: int = Input( + default=-1, + description=" -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally strecthed words, increase sample_batch_size to 2, 3 or even 4", + ), sampling_rate: int = Input( description="Specify the sampling rate of the audio codec", default=16000 ), @@ -97,20 +119,27 @@ class Predictor(BasePredictor): ), ) -> Path: """Run a single prediction on the model""" + + if task == "zero-shot text-to-speech": + assert ( + orig_transcript_until_cutoff_time is not None + ), "Please provide orig_transcript_until_cutoff_time for zero-shot text-to-speech task." if seed is None: seed = int.from_bytes(os.urandom(2), "big") print(f"Using seed: {seed}") seed_everything(seed) - temp_folder = "exp_temp" + temp_folder = "exp_dir" if os.path.exists(temp_folder): shutil.rmtree(temp_folder) os.makedirs(temp_folder) - os.system(f"cp {str(orig_audio)} {temp_folder}") - # filename = os.path.splitext(orig_audio.split("/")[-1])[0] - with open(f"{temp_folder}/orig_audio_file.txt", "w") as f: + + filename = "orig_audio" + shutil.copy(orig_audio, f"{temp_folder}/{filename}.wav") + + with open(f"{temp_folder}/{filename}.txt", "w") as f: f.write(orig_transcript) # run MFA to get the alignment @@ -121,26 +150,61 @@ class Predictor(BasePredictor): subprocess.run(command, shell=True, check=True) except subprocess.CalledProcessError as e: print("Error:", e) + raise RuntimeError("Error running Alignment") + print("Alignment done!") - audio_fn = str(orig_audio) # f"{temp_folder}/{filename}.wav" + align_fn = f"{align_temp}/{filename}.csv" + audio_fn = f"{temp_folder}/{filename}.wav" info = torchaudio.info(audio_fn) audio_dur = info.num_frames / info.sample_rate - assert ( - cut_off_sec < audio_dur - ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}" - prompt_end_frame = int(cut_off_sec * info.sample_rate) - + # hyperparameters for inference + left_margin = 0.08 + right_margin = 0.08 codec_sr = 50 top_k = 0 silence_tokens = [1388, 1898, 131] - kvcache = 1 # NOTE if OOM, change this to 0, or try the 330M model + kvcache = 1 if task == "zero-shot text-to-speech" else 0 - # NOTE adjust the below three arguments if the generation is not as good - stop_repetition = 3 # NOTE if the model generate long silence, reduce the stop_repetition to 3, 2 or even 1 sample_batch_size = 4 # NOTE: if the if there are long silence or unnaturally strecthed words, increase sample_batch_size to 5 or higher. What this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. So if the speech rate of the generated is too fast change it to a smaller number. + if task == "": + assert ( + cut_off_sec < audio_dur + ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}" + prompt_end_frame = int(cut_off_sec * info.sample_rate) + + else: + + edit_type = task.split("-")[-1] + orig_span, new_span = get_span( + orig_transcript, target_transcript, edit_type + ) + if orig_span[0] > orig_span[1]: + RuntimeError(f"example {audio_fn} failed") + if orig_span[0] == orig_span[1]: + orig_span_save = [orig_span[0]] + else: + orig_span_save = orig_span + if new_span[0] == new_span[1]: + new_span_save = [new_span[0]] + else: + new_span_save = new_span + orig_span_save = ",".join([str(item) for item in orig_span_save]) + new_span_save = ",".join([str(item) for item in new_span_save]) + start, end = get_mask_interval(align_fn, orig_span_save, edit_type) + + # span in codec frames + morphed_span = ( + max(start - left_margin, 1 / codec_sr), + min(end + right_margin, audio_dur), + ) # in seconds + mask_interval = [ + [round(morphed_span[0] * codec_sr), round(morphed_span[1] * codec_sr)] + ] + mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now + decode_config = { "top_k": top_k, "top_p": top_p, @@ -150,27 +214,45 @@ class Predictor(BasePredictor): "codec_audio_sr": sampling_rate, "codec_sr": codec_sr, "silence_tokens": silence_tokens, - "sample_batch_size": sample_batch_size, } - concated_audio, gen_audio = inference_one_sample( - self.model, - self.ckpt["config"], - self.phn2num, - self.text_tokenizer, - self.audio_tokenizer, - audio_fn, - orig_transcript_until_cutoff_time.strip() + "" + target_transcript.strip(), - self.device, - decode_config, - prompt_end_frame, - ) + + if task == "zero-shot text-to-speech": + decode_config["sample_batch_size"] = sample_batch_size + + concated_audio, gen_audio = inference_one_sample( + self.model, + self.ckpt["config"], + self.phn2num, + self.text_tokenizer, + self.audio_tokenizer, + audio_fn, + orig_transcript_until_cutoff_time.strip() + + "" + + target_transcript.strip(), + self.device, + decode_config, + prompt_end_frame, + ) + + else: + orig_audio, gen_audio = inference_one_sample_editing( + self.model, + self.ckpt["config"], + self.phn2num, + self.text_tokenizer, + self.audio_tokenizer, + audio_fn, + target_transcript, + mask_interval, + self.device, + decode_config, + ) # save segments for comparison - concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu() + gen_audio = gen_audio[0].cpu() out = "/tmp/out.wav" torchaudio.save(out, gen_audio, sampling_rate) - torchaudio.save("out.wav", gen_audio, sampling_rate) return Path(out) From b8eca5a2d44a4438dc5b8e30593ced2bef89d51f Mon Sep 17 00:00:00 2001 From: chenxwh Date: Fri, 5 Apr 2024 17:58:09 +0000 Subject: [PATCH 3/8] replicate demo --- predict.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/predict.py b/predict.py index d8a2890..8a0fee9 100644 --- a/predict.py +++ b/predict.py @@ -75,7 +75,7 @@ class Predictor(BasePredictor): choices=[ "speech_editing-substitution", "speech_editing-insertion", - "speech_editing-sdeletion", + "speech_editing-deletion", "zero-shot text-to-speech", ], default="speech_editing-substitution", @@ -89,7 +89,7 @@ class Predictor(BasePredictor): ), cut_off_sec: float = Input( description="Valid/Required for zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech (TTS). 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec", - default=3.01, + default=None, ), orig_transcript_until_cutoff_time: str = Input( description="Valid/Required for zero-shot text-to-speech task. Transcript of the original audio file until the cut_off_sec specified above. This process will be improved and made automatically later", @@ -123,7 +123,8 @@ class Predictor(BasePredictor): if task == "zero-shot text-to-speech": assert ( orig_transcript_until_cutoff_time is not None - ), "Please provide orig_transcript_until_cutoff_time for zero-shot text-to-speech task." + and cut_off_sec is not None + ), "Please provide cut_off_sec and orig_transcript_until_cutoff_time for zero-shot text-to-speech task." if seed is None: seed = int.from_bytes(os.urandom(2), "big") print(f"Using seed: {seed}") @@ -169,7 +170,7 @@ class Predictor(BasePredictor): sample_batch_size = 4 # NOTE: if the if there are long silence or unnaturally strecthed words, increase sample_batch_size to 5 or higher. What this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. So if the speech rate of the generated is too fast change it to a smaller number. - if task == "": + if task == "zero-shot text-to-speech": assert ( cut_off_sec < audio_dur ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}" From 0da8ee4b7a2aacb994c424e827f649e7b0f46041 Mon Sep 17 00:00:00 2001 From: chenxwh Date: Sun, 14 Apr 2024 12:15:23 +0000 Subject: [PATCH 4/8] update replicate demo --- cog.yaml | 2 + predict.py | 206 ++++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 152 insertions(+), 56 deletions(-) diff --git a/cog.yaml b/cog.yaml index a771ad5..5204813 100644 --- a/cog.yaml +++ b/cog.yaml @@ -17,6 +17,8 @@ build: - phonemizer==3.2.1 - datasets==2.16.0 - torchmetrics==0.11.1 + - whisperx==3.1.1 + - openai-whisper>=20231117 run: - curl -O https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh - bash Miniconda3-py310_23.3.1-0-Linux-x86_64.sh -b -p /cog/miniconda diff --git a/predict.py b/predict.py index 8a0fee9..001fb8a 100644 --- a/predict.py +++ b/predict.py @@ -2,16 +2,20 @@ # https://github.com/replicate/cog/blob/main/docs/python.md import os +import stat import time -import numpy as np +import warnings import random import getpass -import torch -import torchaudio import shutil import subprocess -import sys -import warnings +import torch +import numpy as np +import torchaudio + +from whisper.model import Whisper, ModelDimensions +from whisper.tokenizer import get_tokenizer +from cog import BasePredictor, Input, Path, BaseModel warnings.filterwarnings("ignore", category=UserWarning) os.environ["USER"] = getpass.getuser() @@ -20,7 +24,6 @@ from data.tokenizer import ( AudioTokenizer, TextTokenizer, ) -from cog import BasePredictor, Input, Path from models import voicecraft from inference_tts_scale import inference_one_sample from edit_utils import get_span @@ -31,11 +34,38 @@ from inference_speech_editing_scale import ( ENV_NAME = "myenv" - -MODEL_URL = "https://weights.replicate.delivery/default/VoiceCraft.tar" +MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft.tar" MODEL_CACHE = "model_cache" +class ModelOutput(BaseModel): + whisper_transcript_orig_audio: str + generated_audio: Path + + +class WhisperModel: + def __init__(self, model_cache, model_name="base.en", device="cuda"): + + with open(f"{model_cache}/{model_name}.pt", "rb") as fp: + checkpoint = torch.load(fp, map_location="cpu") + dims = ModelDimensions(**checkpoint["dims"]) + self.model = Whisper(dims) + self.model.load_state_dict(checkpoint["model_state_dict"]) + self.model.to(device) + + tokenizer = get_tokenizer(multilingual=False) + self.supress_tokens = [-1] + [ + i + for i in range(tokenizer.eot) + if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" ")) + ] + + def transcribe(self, audio_path): + return self.model.transcribe( + audio_path, suppress_tokens=self.supress_tokens, word_timestamps=True + )["segments"] + + def download_weights(url, dest): start = time.time() print("downloading url: ", url) @@ -49,56 +79,87 @@ class Predictor(BasePredictor): """Load the model into memory to make running multiple predictions efficient""" self.device = "cuda" - voicecraft_name = "giga830M.pth" # or giga330M.pth - if not os.path.exists(MODEL_CACHE): download_weights(MODEL_URL, MODEL_CACHE) encodec_fn = f"{MODEL_CACHE}/encodec_4cb2048_giga.th" - ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}" + self.models, self.ckpt, self.phn2num = {}, {}, {} + for voicecraft_name in [ + "giga830M.pth", + "giga330M.pth", + "gigaHalfLibri330M_TTSEnhanced_max16s.pth", + ]: + ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}" - self.ckpt = torch.load(ckpt_fn, map_location="cpu") - self.model = voicecraft.VoiceCraft(self.ckpt["config"]) - self.model.load_state_dict(self.ckpt["model"]) - self.model.to(self.device) - self.model.eval() + self.ckpt[voicecraft_name] = torch.load(ckpt_fn, map_location="cpu") + self.models[voicecraft_name] = voicecraft.VoiceCraft( + self.ckpt[voicecraft_name]["config"] + ) + self.models[voicecraft_name].load_state_dict( + self.ckpt[voicecraft_name]["model"] + ) + self.models[voicecraft_name].to(self.device) + self.models[voicecraft_name].eval() - self.phn2num = self.ckpt["phn2num"] + self.phn2num[voicecraft_name] = self.ckpt[voicecraft_name]["phn2num"] self.text_tokenizer = TextTokenizer(backend="espeak") self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device) + self.transcribe_models = { + k: WhisperModel(MODEL_CACHE, k, self.device) + for k in ["base.en", "small.en", "medium.en"] + } def predict( self, task: str = Input( - description="Choose a task. For zero-shot text-to-speech, you also need to specify the cut_off_sec of the original audio to be used for zero-shot generation and the transcript until the cut_off_sec", + description="Choose a task", choices=[ "speech_editing-substitution", "speech_editing-insertion", "speech_editing-deletion", "zero-shot text-to-speech", ], - default="speech_editing-substitution", + default="zero-shot text-to-speech", + ), + voicecraft_model: str = Input( + description="Choose a model", + choices=["giga830M.pth", "giga330M.pth", "giga330M_TTSEnhanced.pth"], + default="giga330M_TTSEnhanced.pth", + ), + orig_audio: Path = Input( + description="Original audio file. WhisperX small.en model will be used for transcription" ), - orig_audio: Path = Input(description="Original audio file"), orig_transcript: str = Input( - description="Transcript of the original audio file. You can use models such as https://replicate.com/openai/whisper and https://replicate.com/vaibhavs10/incredibly-fast-whisper to get the transcript (and modify it if it's not accurate)", + description="Optionally provide the transcript of the input audio. Leave it blank to use the whisper model below to generate the transcript. Inaccurate transcription may lead to error TTS or speech editing", + default="", + ), + whisper_model: str = Input( + description="If orig_transcript is not provided above, choose a Whisper model. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ", + choices=["base.en", "small.en", "medium.en"], + default="base.en", ), target_transcript: str = Input( description="Transcript of the target audio file", ), cut_off_sec: float = Input( - description="Valid/Required for zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech (TTS). 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec", - default=None, + description="Only used for for zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech. 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec", + default=3.01, ), - orig_transcript_until_cutoff_time: str = Input( - description="Valid/Required for zero-shot text-to-speech task. Transcript of the original audio file until the cut_off_sec specified above. This process will be improved and made automatically later", - default=None, + kvcache: int = Input( + description="Set to 0 to use less VRAM, but with slower inference", + default=1, + ), + left_margin: float = Input( + description="Margin to the left of the editing segment", + default=0.08, + ), + right_margin: float = Input( + description="Margin to the right of the editing segment", + default=0.08, ), temperature: float = Input( - description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", - ge=0.01, - le=5, + description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic. Do not recommend to change", default=1, ), top_p: float = Input( @@ -109,28 +170,33 @@ class Predictor(BasePredictor): ), stop_repetition: int = Input( default=-1, - description=" -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally strecthed words, increase sample_batch_size to 2, 3 or even 4", + description=" -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4", ), - sampling_rate: int = Input( - description="Specify the sampling rate of the audio codec", default=16000 + sample_batch_size: int = Input( + description="The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one", + default=4, ), seed: int = Input( description="Random seed. Leave blank to randomize the seed", default=None ), - ) -> Path: + ) -> ModelOutput: """Run a single prediction on the model""" - if task == "zero-shot text-to-speech": - assert ( - orig_transcript_until_cutoff_time is not None - and cut_off_sec is not None - ), "Please provide cut_off_sec and orig_transcript_until_cutoff_time for zero-shot text-to-speech task." if seed is None: seed = int.from_bytes(os.urandom(2), "big") print(f"Using seed: {seed}") seed_everything(seed) + segments = self.transcribe_models[whisper_model].transcribe(str(orig_audio)) + state = get_transcribe_state(segments) + whisper_transcript = state["transcript"].strip() + + if len(orig_transcript.strip()) == 0: + orig_transcript = whisper_transcript + + print(f"The transcript from the Whisper model: {whisper_transcript}") + temp_folder = "exp_dir" if os.path.exists(temp_folder): shutil.rmtree(temp_folder) @@ -161,14 +227,13 @@ class Predictor(BasePredictor): audio_dur = info.num_frames / info.sample_rate # hyperparameters for inference - left_margin = 0.08 - right_margin = 0.08 + codec_audio_sr = 16000 codec_sr = 50 top_k = 0 silence_tokens = [1388, 1898, 131] - kvcache = 1 if task == "zero-shot text-to-speech" else 0 - sample_batch_size = 4 # NOTE: if the if there are long silence or unnaturally strecthed words, increase sample_batch_size to 5 or higher. What this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. So if the speech rate of the generated is too fast change it to a smaller number. + if voicecraft_model == "giga330M_TTSEnhanced.pth": + voicecraft_model = "gigaHalfLibri330M_TTSEnhanced_max16s.pth" if task == "zero-shot text-to-speech": assert ( @@ -176,8 +241,11 @@ class Predictor(BasePredictor): ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}" prompt_end_frame = int(cut_off_sec * info.sample_rate) + idx = find_closest_cut_off_word(state["word_bounds"], cut_off_sec) + orig_transcript_until_cutoff_time = "".join( + [word_bound["word"] for word_bound in state["word_bounds"][:idx]] + ) else: - edit_type = task.split("-")[-1] orig_span, new_span = get_span( orig_transcript, target_transcript, edit_type @@ -212,18 +280,17 @@ class Predictor(BasePredictor): "temperature": temperature, "stop_repetition": stop_repetition, "kvcache": kvcache, - "codec_audio_sr": sampling_rate, + "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr, "silence_tokens": silence_tokens, } if task == "zero-shot text-to-speech": decode_config["sample_batch_size"] = sample_batch_size - - concated_audio, gen_audio = inference_one_sample( - self.model, - self.ckpt["config"], - self.phn2num, + _, gen_audio = inference_one_sample( + self.models[voicecraft_model], + self.ckpt[voicecraft_model]["config"], + self.phn2num[voicecraft_model], self.text_tokenizer, self.audio_tokenizer, audio_fn, @@ -234,12 +301,11 @@ class Predictor(BasePredictor): decode_config, prompt_end_frame, ) - else: - orig_audio, gen_audio = inference_one_sample_editing( - self.model, - self.ckpt["config"], - self.phn2num, + _, gen_audio = inference_one_sample_editing( + self.models[voicecraft_model], + self.ckpt[voicecraft_model]["config"], + self.phn2num[voicecraft_model], self.text_tokenizer, self.audio_tokenizer, audio_fn, @@ -253,8 +319,10 @@ class Predictor(BasePredictor): gen_audio = gen_audio[0].cpu() out = "/tmp/out.wav" - torchaudio.save(out, gen_audio, sampling_rate) - return Path(out) + torchaudio.save(out, gen_audio, codec_audio_sr) + return ModelOutput( + generated_audio=Path(out), whisper_transcript_orig_audio=whisper_transcript + ) def seed_everything(seed): @@ -265,3 +333,29 @@ def seed_everything(seed): torch.cuda.manual_seed(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True + + +def get_transcribe_state(segments): + words_info = [word_info for segment in segments for word_info in segment["words"]] + return { + "transcript": " ".join([segment["text"].strip() for segment in segments]), + "word_bounds": [ + {"word": word["word"], "start": word["start"], "end": word["end"]} + for word in words_info + ], + } + + +def find_closest_cut_off_word(word_bounds, cut_off_sec): + min_distance = float("inf") + + for i, word_bound in enumerate(word_bounds): + distance = abs(word_bound["start"] - cut_off_sec) + + if distance < min_distance: + min_distance = distance + + if word_bound["end"] > cut_off_sec: + break + + return i From 9746a1f60c4d0c0289c6914eb7cc2b662e7bd5ca Mon Sep 17 00:00:00 2001 From: chenxwh Date: Fri, 19 Apr 2024 10:46:00 +0000 Subject: [PATCH 5/8] update with whisperx --- .gitignore | 4 +- cog.yaml | 25 ++++------ predict.py | 133 ++++++++++++++++++++++++++++++++++++++++------------- 3 files changed, 113 insertions(+), 49 deletions(-) diff --git a/.gitignore b/.gitignore index 17dbc9b..90a560c 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,6 @@ thumbs.db src/audiocraft !/demo/ -!/demo/* \ No newline at end of file +!/demo/* + +.cog/tmp/* \ No newline at end of file diff --git a/cog.yaml b/cog.yaml index 5204813..a020931 100644 --- a/cog.yaml +++ b/cog.yaml @@ -4,30 +4,21 @@ build: gpu: true system_packages: - - "libgl1-mesa-glx" - - "libglib2.0-0" + - libgl1-mesa-glx + - libglib2.0-0 - ffmpeg - espeak-ng - python_version: "3.9.16" + python_version: "3.11" python_packages: - - torch==2.0.1 - - torchaudio==2.0.2 - - xformers==0.0.22 - - tensorboard==2.16.2 + - torch==2.1.0 + - torchaudio==2.1.0 + - xformers - phonemizer==3.2.1 - - datasets==2.16.0 - - torchmetrics==0.11.1 - whisperx==3.1.1 - openai-whisper>=20231117 run: - - curl -O https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh - - bash Miniconda3-py310_23.3.1-0-Linux-x86_64.sh -b -p /cog/miniconda - - /cog/miniconda/bin/conda init bash - - /bin/bash -c "source /cog/miniconda/bin/activate && conda create -n myenv python=3.9.16 -y" - - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068 -y" - - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && mfa model download dictionary english_us_arpa && mfa model download acoustic english_us_arpa" - - export PATH=/cog/miniconda/envs/myenv/bin:$PATH - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft - pip install "pydantic<2.0.0" - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget -predict: "predict.py:Predictor" \ No newline at end of file + - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" +predict: "predict.py:Predictor" diff --git a/predict.py b/predict.py index 001fb8a..4258076 100644 --- a/predict.py +++ b/predict.py @@ -2,9 +2,7 @@ # https://github.com/replicate/cog/blob/main/docs/python.md import os -import stat import time -import warnings import random import getpass import shutil @@ -12,12 +10,10 @@ import subprocess import torch import numpy as np import torchaudio - from whisper.model import Whisper, ModelDimensions from whisper.tokenizer import get_tokenizer from cog import BasePredictor, Input, Path, BaseModel -warnings.filterwarnings("ignore", category=UserWarning) os.environ["USER"] = getpass.getuser() from data.tokenizer import ( @@ -27,14 +23,12 @@ from data.tokenizer import ( from models import voicecraft from inference_tts_scale import inference_one_sample from edit_utils import get_span -from inference_speech_editing_scale import get_mask_interval from inference_speech_editing_scale import ( inference_one_sample as inference_one_sample_editing, ) -ENV_NAME = "myenv" -MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft.tar" +MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft-models.tar" # all the models are cached and uploaded to replicate.delivery for faster booting MODEL_CACHE = "model_cache" @@ -43,9 +37,56 @@ class ModelOutput(BaseModel): generated_audio: Path +class WhisperxAlignModel: + def __init__(self): + from whisperx import load_align_model + + self.model, self.metadata = load_align_model( + language_code="en", device="cuda:0" + ) + + def align(self, segments, audio_path): + from whisperx import align, load_audio + + audio = load_audio(audio_path) + return align( + segments, + self.model, + self.metadata, + audio, + device="cuda:0", + return_char_alignments=False, + )["segments"] + + +class WhisperxModel: + def __init__(self, model_name, align_model: WhisperxAlignModel, device="cuda"): + from whisperx import load_model + + # the model weights are cached from Systran/faster-whisper-base.en etc + self.model = load_model( + model_name, + device, + asr_options={ + "suppress_numerals": True, + "max_new_tokens": None, + "clip_timestamps": None, + "hallucination_silence_threshold": None, + }, + ) + self.align_model = align_model + + def transcribe(self, audio_path): + segments = self.model.transcribe(audio_path, language="en", batch_size=8)[ + "segments" + ] + return self.align_model.align(segments, audio_path) + + class WhisperModel: def __init__(self, model_cache, model_name="base.en", device="cuda"): + # the model weights are cached from https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/__init__.py#L17 with open(f"{model_cache}/{model_name}.pt", "rb") as fp: checkpoint = torch.load(fp, map_location="cpu") dims = ModelDimensions(**checkpoint["dims"]) @@ -105,11 +146,17 @@ class Predictor(BasePredictor): self.text_tokenizer = TextTokenizer(backend="espeak") self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device) - self.transcribe_models = { + self.transcribe_models_whisper = { k: WhisperModel(MODEL_CACHE, k, self.device) for k in ["base.en", "small.en", "medium.en"] } + align_model = WhisperxAlignModel() + self.transcribe_models_whisperx = { + k: WhisperxModel(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}", align_model) + for k in ["base.en", "small.en", "medium.en"] + } + def predict( self, task: str = Input( @@ -135,9 +182,16 @@ class Predictor(BasePredictor): default="", ), whisper_model: str = Input( - description="If orig_transcript is not provided above, choose a Whisper model. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ", - choices=["base.en", "small.en", "medium.en"], - default="base.en", + description="If orig_transcript is not provided above, choose a Whisper or WhisperX model. WhisperX model contains extra alignment steps. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ", + choices=[ + "whisper-base.en", + "whisper-small.en", + "whisper-medium.en", + "whisperx-base.en", + "whisperx-small.en", + "whisperx-medium.en", + ], + default="whisper-base.en", ), target_transcript: str = Input( description="Transcript of the target audio file", @@ -188,8 +242,19 @@ class Predictor(BasePredictor): seed_everything(seed) - segments = self.transcribe_models[whisper_model].transcribe(str(orig_audio)) + whisper_model, whisper_model_size = whisper_model.split("-") + + if whisper_model == "whisper": + segments = self.transcribe_models_whisper[whisper_model_size].transcribe( + str(orig_audio) + ) + else: + segments = self.transcribe_models_whisperx[whisper_model_size].transcribe( + str(orig_audio) + ) + state = get_transcribe_state(segments) + whisper_transcript = state["transcript"].strip() if len(orig_transcript.strip()) == 0: @@ -204,25 +269,8 @@ class Predictor(BasePredictor): os.makedirs(temp_folder) filename = "orig_audio" - shutil.copy(orig_audio, f"{temp_folder}/{filename}.wav") + audio_fn = str(orig_audio) - with open(f"{temp_folder}/{filename}.txt", "w") as f: - f.write(orig_transcript) - - # run MFA to get the alignment - align_temp = f"{temp_folder}/mfa_alignments" - - command = f'/bin/bash -c "source /cog/miniconda/bin/activate && conda activate {ENV_NAME} && mfa align -v --clean -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}"' - try: - subprocess.run(command, shell=True, check=True) - except subprocess.CalledProcessError as e: - print("Error:", e) - raise RuntimeError("Error running Alignment") - - print("Alignment done!") - - align_fn = f"{align_temp}/{filename}.csv" - audio_fn = f"{temp_folder}/{filename}.wav" info = torchaudio.info(audio_fn) audio_dur = info.num_frames / info.sample_rate @@ -262,7 +310,10 @@ class Predictor(BasePredictor): new_span_save = new_span orig_span_save = ",".join([str(item) for item in orig_span_save]) new_span_save = ",".join([str(item) for item in new_span_save]) - start, end = get_mask_interval(align_fn, orig_span_save, edit_type) + + start, end = get_mask_interval_from_word_bounds( + state["word_bounds"], orig_span_save, edit_type + ) # span in codec frames morphed_span = ( @@ -359,3 +410,23 @@ def find_closest_cut_off_word(word_bounds, cut_off_sec): break return i + + +def get_mask_interval_from_word_bounds(word_bounds, word_span_ind, editType): + tmp = word_span_ind.split(",") + s, e = int(tmp[0]), int(tmp[-1]) + start = None + for j, item in enumerate(word_bounds): + if j == s: + if editType == "insertion": + start = float(item["end"]) + else: + start = float(item["start"]) + if j == e: + if editType == "insertion": + end = float(item["start"]) + else: + end = float(item["end"]) + assert start != None + break + return (start, end) From 729d0ec69e90ae0e8e89ae35b4daec1c06722c0b Mon Sep 17 00:00:00 2001 From: chenxwh Date: Fri, 19 Apr 2024 10:48:15 +0000 Subject: [PATCH 6/8] update audiocraft install --- cog.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cog.yaml b/cog.yaml index a020931..5dd3888 100644 --- a/cog.yaml +++ b/cog.yaml @@ -17,7 +17,8 @@ build: - whisperx==3.1.1 - openai-whisper>=20231117 run: - - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft + # - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft + - pip install -e git+https://github.com/facebookresearch/audiocraft.git@f83babff6b5e97f75562127c4cc8122229c8f099#egg=audiocraft # use "git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft" instead if hits audiocraft import error - pip install "pydantic<2.0.0" - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" From 2a2ee984b65fcf21e5d115e0cac7b94cf8b882c5 Mon Sep 17 00:00:00 2001 From: chenxwh Date: Fri, 19 Apr 2024 10:49:04 +0000 Subject: [PATCH 7/8] update audiocraft install --- cog.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cog.yaml b/cog.yaml index 5dd3888..f219cb0 100644 --- a/cog.yaml +++ b/cog.yaml @@ -18,7 +18,7 @@ build: - openai-whisper>=20231117 run: # - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft - - pip install -e git+https://github.com/facebookresearch/audiocraft.git@f83babff6b5e97f75562127c4cc8122229c8f099#egg=audiocraft # use "git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft" instead if hits audiocraft import error + - pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft # use "git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft" instead if hits audiocraft import error - pip install "pydantic<2.0.0" - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" From 87f4fa5d21439f9951ffbc404178f7b86da304b8 Mon Sep 17 00:00:00 2001 From: chenxwh Date: Sun, 21 Apr 2024 22:30:56 +0000 Subject: [PATCH 8/8] update --- audiocraft | 1 - cog.yaml | 3 +- predict.py | 87 ++++++++++++++---------------------------------------- 3 files changed, 23 insertions(+), 68 deletions(-) delete mode 160000 audiocraft diff --git a/audiocraft b/audiocraft deleted file mode 160000 index 69fea8b..0000000 --- a/audiocraft +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85 diff --git a/cog.yaml b/cog.yaml index f219cb0..a020931 100644 --- a/cog.yaml +++ b/cog.yaml @@ -17,8 +17,7 @@ build: - whisperx==3.1.1 - openai-whisper>=20231117 run: - # - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft - - pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft # use "git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft" instead if hits audiocraft import error + - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft - pip install "pydantic<2.0.0" - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" diff --git a/predict.py b/predict.py index 4258076..951be42 100644 --- a/predict.py +++ b/predict.py @@ -10,8 +10,6 @@ import subprocess import torch import numpy as np import torchaudio -from whisper.model import Whisper, ModelDimensions -from whisper.tokenizer import get_tokenizer from cog import BasePredictor, Input, Path, BaseModel os.environ["USER"] = getpass.getuser() @@ -83,30 +81,6 @@ class WhisperxModel: return self.align_model.align(segments, audio_path) -class WhisperModel: - def __init__(self, model_cache, model_name="base.en", device="cuda"): - - # the model weights are cached from https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/__init__.py#L17 - with open(f"{model_cache}/{model_name}.pt", "rb") as fp: - checkpoint = torch.load(fp, map_location="cpu") - dims = ModelDimensions(**checkpoint["dims"]) - self.model = Whisper(dims) - self.model.load_state_dict(checkpoint["model_state_dict"]) - self.model.to(device) - - tokenizer = get_tokenizer(multilingual=False) - self.supress_tokens = [-1] + [ - i - for i in range(tokenizer.eot) - if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" ")) - ] - - def transcribe(self, audio_path): - return self.model.transcribe( - audio_path, suppress_tokens=self.supress_tokens, word_timestamps=True - )["segments"] - - def download_weights(url, dest): start = time.time() print("downloading url: ", url) @@ -146,13 +120,9 @@ class Predictor(BasePredictor): self.text_tokenizer = TextTokenizer(backend="espeak") self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device) - self.transcribe_models_whisper = { - k: WhisperModel(MODEL_CACHE, k, self.device) - for k in ["base.en", "small.en", "medium.en"] - } align_model = WhisperxAlignModel() - self.transcribe_models_whisperx = { + self.transcribe_models = { k: WhisperxModel(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}", align_model) for k in ["base.en", "small.en", "medium.en"] } @@ -174,24 +144,19 @@ class Predictor(BasePredictor): choices=["giga830M.pth", "giga330M.pth", "giga330M_TTSEnhanced.pth"], default="giga330M_TTSEnhanced.pth", ), - orig_audio: Path = Input( - description="Original audio file. WhisperX small.en model will be used for transcription" - ), + orig_audio: Path = Input(description="Original audio file"), orig_transcript: str = Input( - description="Optionally provide the transcript of the input audio. Leave it blank to use the whisper model below to generate the transcript. Inaccurate transcription may lead to error TTS or speech editing", + description="Optionally provide the transcript of the input audio. Leave it blank to use the WhisperX model below to generate the transcript. Inaccurate transcription may lead to error TTS or speech editing", default="", ), - whisper_model: str = Input( - description="If orig_transcript is not provided above, choose a Whisper or WhisperX model. WhisperX model contains extra alignment steps. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ", + whisperx_model: str = Input( + description="If orig_transcript is not provided above, choose a WhisperX model for generating the transcript. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to orig_transcript above", choices=[ - "whisper-base.en", - "whisper-small.en", - "whisper-medium.en", - "whisperx-base.en", - "whisperx-small.en", - "whisperx-medium.en", + "base.en", + "small.en", + "medium.en", ], - default="whisper-base.en", + default="base.en", ), target_transcript: str = Input( description="Transcript of the target audio file", @@ -202,6 +167,7 @@ class Predictor(BasePredictor): ), kvcache: int = Input( description="Set to 0 to use less VRAM, but with slower inference", + choices=[0, 1], default=1, ), left_margin: float = Input( @@ -217,17 +183,15 @@ class Predictor(BasePredictor): default=1, ), top_p: float = Input( - description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", - ge=0.0, - le=1.0, - default=0.8, + description="Default value for TTS is 0.9, and 0.8 for speech editing", + default=0.9, ), stop_repetition: int = Input( - default=-1, - description=" -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4", + default=3, + description="Default value for TTS is 3, and -1 for speech editing. -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4", ), sample_batch_size: int = Input( - description="The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one", + description="Default value for TTS is 4, and 1 for speech editing. The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one", default=4, ), seed: int = Input( @@ -242,16 +206,9 @@ class Predictor(BasePredictor): seed_everything(seed) - whisper_model, whisper_model_size = whisper_model.split("-") - - if whisper_model == "whisper": - segments = self.transcribe_models_whisper[whisper_model_size].transcribe( - str(orig_audio) - ) - else: - segments = self.transcribe_models_whisperx[whisper_model_size].transcribe( - str(orig_audio) - ) + segments = self.transcribe_models[whisperx_model].transcribe( + str(orig_audio) + ) state = get_transcribe_state(segments) @@ -290,8 +247,8 @@ class Predictor(BasePredictor): prompt_end_frame = int(cut_off_sec * info.sample_rate) idx = find_closest_cut_off_word(state["word_bounds"], cut_off_sec) - orig_transcript_until_cutoff_time = "".join( - [word_bound["word"] for word_bound in state["word_bounds"][:idx]] + orig_transcript_until_cutoff_time = " ".join( + [word_bound["word"] for word_bound in state["word_bounds"][: idx + 1]] ) else: edit_type = task.split("-")[-1] @@ -346,7 +303,7 @@ class Predictor(BasePredictor): self.audio_tokenizer, audio_fn, orig_transcript_until_cutoff_time.strip() - + "" + + " " + target_transcript.strip(), self.device, decode_config, @@ -427,6 +384,6 @@ def get_mask_interval_from_word_bounds(word_bounds, word_span_ind, editType): end = float(item["start"]) else: end = float(item["end"]) - assert start != None + assert start is not None break return (start, end)