Compare commits
8 commits: bb45dd5795...9ffb152332
| Author | SHA1 | Date |
|---|---|---|
| Chenxi | 9ffb152332 | |
| Chenxi | e3fc926ca4 | |
| chenxwh | 0c6942fd2a | |
| chenxwh | f649f9216b | |
| Chenxi | 1e2f8391a7 | |
| chenxwh | b8eca5a2d4 | |
| chenxwh | 023d4b1c6c | |
| chenxwh | 49a648fa54 | |
Binary file not shown.
@@ -0,0 +1,8 @@
```
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.0.1
torchaudio==2.0.2
xformers==0.0.22
tensorboard==2.16.2
phonemizer==3.2.1
datasets==2.16.0
torchmetrics==0.11.1
```
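As a quick sanity check, the pins above can be compared against the active environment with the standard library alone; a minimal sketch (the package names are copied from the file above):

```python
# Minimal sketch: verify the pinned requirement versions against the
# installed environment. Standard library only (Python 3.8+).
from importlib.metadata import version, PackageNotFoundError

PINS = {
    "torch": "2.0.1",
    "torchaudio": "2.0.2",
    "xformers": "0.0.22",
    "tensorboard": "2.16.2",
    "phonemizer": "3.2.1",
    "datasets": "2.16.0",
    "torchmetrics": "0.11.1",
}

for name, expected in PINS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {expected})")
        continue
    status = "ok" if installed == expected else f"mismatch (got {installed})"
    print(f"{name}=={expected}: {status}")
```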
@@ -0,0 +1,17 @@
```
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file

# Exclude Git files
.git
.github
.gitignore

# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache

# Exclude Python virtual environment
/venv
```
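To sanity-check what these patterns exclude, gitignore-style matching can be simulated with the third-party `pathspec` package; this is only an approximation, since Docker's own `.dockerignore` matcher differs in some edge cases:

```python
# Illustrative check of the exclusion patterns above using the third-party
# `pathspec` package (pip install pathspec). Docker's matcher may differ in
# edge cases; this is an approximation, not part of the diff.
import pathspec

patterns = [
    ".git", ".github", ".gitignore",
    "__pycache__", ".mypy_cache", ".pytest_cache", ".ruff_cache",
    "/venv",
]
spec = pathspec.PathSpec.from_lines("gitwildmatch", patterns)

for path in ["predict.py", "venv/bin/python", "models/__pycache__/m.pyc"]:
    print(f"{path}: {'excluded' if spec.match_file(path) else 'kept'}")
```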
@@ -1,5 +1,6 @@
```diff
 # VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild
-[![Paper](https://img.shields.io/badge/arXiv-2301.12503-brightgreen.svg?style=flat-square)](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf) [![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://jasonppy.github.io/VoiceCraft_web/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IOjpglQyMTO2C3Y94LD9FY0Ocn-RJRg6?usp=sharing)
+[![Paper](https://img.shields.io/badge/arXiv-2301.12503-brightgreen.svg?style=flat-square)](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf) [![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://jasonppy.github.io/VoiceCraft_web/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IOjpglQyMTO2C3Y94LD9FY0Ocn-RJRg6?usp=sharing) [![Replicate](https://replicate.com/cjwbw/voicecraft/badge)](https://replicate.com/cjwbw/voicecraft)
+
 
 ### TL;DR
 VoiceCraft is a token infilling neural codec language model that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data including audiobooks, internet videos, and podcasts.
```
@@ -0,0 +1 @@
```
Subproject commit 69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85
```
@@ -0,0 +1,31 @@
```yaml
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  gpu: true
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
    - ffmpeg
    - espeak-ng
  python_version: "3.9.16"
  python_packages:
    - torch==2.0.1
    - torchaudio==2.0.2
    - xformers==0.0.22
    - tensorboard==2.16.2
    - phonemizer==3.2.1
    - datasets==2.16.0
    - torchmetrics==0.11.1
  run:
    - curl -O https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh
    - bash Miniconda3-py310_23.3.1-0-Linux-x86_64.sh -b -p /cog/miniconda
    - /cog/miniconda/bin/conda init bash
    - /bin/bash -c "source /cog/miniconda/bin/activate && conda create -n myenv python=3.9.16 -y"
    - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068 -y"
    - /bin/bash -c "source /cog/miniconda/bin/activate && conda activate myenv && mfa model download dictionary english_us_arpa && mfa model download acoustic english_us_arpa"
    - export PATH=/cog/miniconda/envs/myenv/bin:$PATH
    - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft
    - pip install "pydantic<2.0.0"
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
predict: "predict.py:Predictor"
```
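Once published, the model behind this configuration can be invoked from Python with the Replicate client; a sketch, assuming `pip install replicate` and a `REPLICATE_API_TOKEN` in the environment. The slug matches the badge added to the README, the inputs mirror the `predict()` signature below, and the sample values are placeholders:

```python
# Sketch: call the published model via the Replicate Python client.
# Assumes `pip install replicate` and REPLICATE_API_TOKEN set; the input
# values here are illustrative placeholders.
import replicate

output = replicate.run(
    "cjwbw/voicecraft",  # pin a specific version with "cjwbw/voicecraft:<id>"
    input={
        "task": "zero-shot text-to-speech",
        "orig_audio": open("prompt.wav", "rb"),
        "orig_transcript": "But when I had approached so near to them,",
        "cut_off_sec": 3.6,
        "orig_transcript_until_cutoff_time": "But when I had approached so near to them,",
        "target_transcript": "I saw that they were travelers like myself.",
    },
)
print(output)  # URL of the generated audio
```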
@@ -0,0 +1,267 @@
```python
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import time
import numpy as np
import random
import getpass
import torch
import torchaudio
import shutil
import subprocess
import sys
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
os.environ["USER"] = getpass.getuser()

from data.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
)
from cog import BasePredictor, Input, Path
from models import voicecraft
from inference_tts_scale import inference_one_sample
from edit_utils import get_span
from inference_speech_editing_scale import get_mask_interval
from inference_speech_editing_scale import (
    inference_one_sample as inference_one_sample_editing,
)

ENV_NAME = "myenv"

MODEL_URL = "https://weights.replicate.delivery/default/VoiceCraft.tar"
MODEL_CACHE = "model_cache"


def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.device = "cuda"

        voicecraft_name = "giga830M.pth"  # or giga330M.pth

        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        encodec_fn = f"{MODEL_CACHE}/encodec_4cb2048_giga.th"
        ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}"

        self.ckpt = torch.load(ckpt_fn, map_location="cpu")
        self.model = voicecraft.VoiceCraft(self.ckpt["config"])
        self.model.load_state_dict(self.ckpt["model"])
        self.model.to(self.device)
        self.model.eval()

        self.phn2num = self.ckpt["phn2num"]

        self.text_tokenizer = TextTokenizer(backend="espeak")
        self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)

    def predict(
        self,
        task: str = Input(
            description="Choose a task. For zero-shot text-to-speech, you also need to specify the cut_off_sec of the original audio to be used for zero-shot generation, and the transcript up to that point",
            choices=[
                "speech_editing-substitution",
                "speech_editing-insertion",
                "speech_editing-deletion",
                "zero-shot text-to-speech",
            ],
            default="speech_editing-substitution",
        ),
        orig_audio: Path = Input(description="Original audio file"),
        orig_transcript: str = Input(
            description="Transcript of the original audio file. You can use models such as https://replicate.com/openai/whisper and https://replicate.com/vaibhavs10/incredibly-fast-whisper to get the transcript (and correct it if it's not accurate)",
        ),
        target_transcript: str = Input(
            description="Transcript of the target audio file",
        ),
        cut_off_sec: float = Input(
            description="Required for the zero-shot text-to-speech task. The first seconds of the original audio are used as the reference for zero-shot TTS. 3 seconds of reference is generally enough for high-quality voice cloning, but longer is usually better; try e.g. 3~6 seconds",
            default=None,
        ),
        orig_transcript_until_cutoff_time: str = Input(
            description="Required for the zero-shot text-to-speech task. Transcript of the original audio file up to the cut_off_sec specified above. This step will be automated later",
            default=None,
        ),
        temperature: float = Input(
            description="Adjusts randomness of outputs; greater than 1 is more random, 0 is deterministic",
            ge=0.01,
            le=5,
            default=1,
        ),
        top_p: float = Input(
            description="When decoding text, samples from the top p fraction of most likely tokens; lower to ignore less likely tokens",
            ge=0.0,
            le=1.0,
            default=0.8,
        ),
        stop_repetition: int = Input(
            default=-1,
            description="-1 means do not adjust the probability of silence tokens. If there are long silences or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4",
        ),
        sampling_rate: int = Input(
            description="Sampling rate of the audio codec", default=16000
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model"""

        if task == "zero-shot text-to-speech":
            assert (
                orig_transcript_until_cutoff_time is not None
                and cut_off_sec is not None
            ), "Please provide cut_off_sec and orig_transcript_until_cutoff_time for the zero-shot text-to-speech task."
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        seed_everything(seed)

        temp_folder = "exp_dir"
        if os.path.exists(temp_folder):
            shutil.rmtree(temp_folder)
        os.makedirs(temp_folder)

        filename = "orig_audio"
        shutil.copy(orig_audio, f"{temp_folder}/{filename}.wav")

        with open(f"{temp_folder}/{filename}.txt", "w") as f:
            f.write(orig_transcript)

        # run MFA to get the alignment
        align_temp = f"{temp_folder}/mfa_alignments"

        command = f'/bin/bash -c "source /cog/miniconda/bin/activate && conda activate {ENV_NAME} && mfa align -v --clean -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}"'
        try:
            subprocess.run(command, shell=True, check=True)
        except subprocess.CalledProcessError as e:
            print("Error:", e)
            raise RuntimeError("Error running alignment")

        print("Alignment done!")
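
        # MFA writes {filename}.csv into align_temp; get_mask_interval below
        # reads the aligned word timings from it to locate the span to edit.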
        align_fn = f"{align_temp}/{filename}.csv"
        audio_fn = f"{temp_folder}/{filename}.wav"
        info = torchaudio.info(audio_fn)
        audio_dur = info.num_frames / info.sample_rate

        # hyperparameters for inference
        left_margin = 0.08
        right_margin = 0.08
        codec_sr = 50
        top_k = 0
        silence_tokens = [1388, 1898, 131]
        kvcache = 1 if task == "zero-shot text-to-speech" else 0

        # NOTE: if there are long silences or unnaturally stretched words,
        # increase sample_batch_size to 5 or higher. The model will then run
        # sample_batch_size copies of the same example and pick the shortest
        # generation. Conversely, if the generated speech rate is too fast,
        # change it to a smaller number.
        sample_batch_size = 4

        if task == "zero-shot text-to-speech":
            assert (
                cut_off_sec < audio_dur
            ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
            prompt_end_frame = int(cut_off_sec * info.sample_rate)

        else:
            edit_type = task.split("-")[-1]
            orig_span, new_span = get_span(
                orig_transcript, target_transcript, edit_type
            )
            if orig_span[0] > orig_span[1]:
                raise RuntimeError(f"example {audio_fn} failed")
            if orig_span[0] == orig_span[1]:
                orig_span_save = [orig_span[0]]
            else:
                orig_span_save = orig_span
            if new_span[0] == new_span[1]:
                new_span_save = [new_span[0]]
            else:
                new_span_save = new_span
            orig_span_save = ",".join([str(item) for item in orig_span_save])
            new_span_save = ",".join([str(item) for item in new_span_save])
            start, end = get_mask_interval(align_fn, orig_span_save, edit_type)

            # convert the masked span from seconds to codec frames
            morphed_span = (
                max(start - left_margin, 1 / codec_sr),
                min(end + right_margin, audio_dur),
            )  # in seconds
            mask_interval = [
                [round(morphed_span[0] * codec_sr), round(morphed_span[1] * codec_sr)]
            ]
            mask_interval = torch.LongTensor(mask_interval)  # [M,2], M==1 for now
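            # e.g. start=2.50s and end=3.10s with the 0.08s margins give
            # morphed_span=(2.42, 3.18)s, i.e. mask_interval=[[121, 159]]
            # at codec_sr=50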

        decode_config = {
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
            "stop_repetition": stop_repetition,
            "kvcache": kvcache,
            "codec_audio_sr": sampling_rate,
            "codec_sr": codec_sr,
            "silence_tokens": silence_tokens,
        }

        if task == "zero-shot text-to-speech":
            decode_config["sample_batch_size"] = sample_batch_size

            concated_audio, gen_audio = inference_one_sample(
                self.model,
                self.ckpt["config"],
                self.phn2num,
                self.text_tokenizer,
                self.audio_tokenizer,
                audio_fn,
                orig_transcript_until_cutoff_time.strip()
                + " "
                + target_transcript.strip(),
                self.device,
                decode_config,
                prompt_end_frame,
            )

        else:
            orig_audio, gen_audio = inference_one_sample_editing(
                self.model,
                self.ckpt["config"],
                self.phn2num,
                self.text_tokenizer,
                self.audio_tokenizer,
                audio_fn,
                target_transcript,
                mask_interval,
                self.device,
                decode_config,
            )

        # save segments for comparison
        gen_audio = gen_audio[0].cpu()

        out = "/tmp/out.wav"
        torchaudio.save(out, gen_audio, sampling_rate)
        return Path(out)


def seed_everything(seed):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
```
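For local testing, the predictor can also be exercised directly; a minimal sketch, assuming a CUDA machine set up as in cog.yaml above (the demo file path and transcripts are placeholders):

```python
# Minimal local smoke test; assumes the repo, downloaded weights, and the
# MFA conda environment from cog.yaml. Path and transcripts are placeholders.
from predict import Predictor

p = Predictor()
p.setup()  # downloads weights into model_cache/ on first use

out = p.predict(
    task="speech_editing-substitution",
    orig_audio="demo/sample.wav",  # placeholder path
    orig_transcript="But when I had approached so near to them,",
    target_transcript="But when I had approached so close to them,",
    cut_off_sec=None,
    orig_transcript_until_cutoff_time=None,
    temperature=1,
    top_p=0.8,
    stop_repetition=-1,
    sampling_rate=16000,
    seed=42,
)
print(out)  # -> /tmp/out.wav
```

In deployment the same interface is driven through `cog predict` or the Replicate API rather than a direct import.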