Compare commits
24 Commits
50c43075e7
...
e040ea7080
Author | SHA1 | Date |
---|---|---|
Jay | e040ea7080 | |
pyp_l40 | 4a3a8f11a7 | |
pyp_l40 | 8d1177149b | |
pyp_l40 | 4ff9930b8e | |
pyp_l40 | 96f6f9fc7a | |
chenxwh | ee3955d57e | |
chenxwh | 87f4fa5d21 | |
chenxwh | 2a2ee984b6 | |
chenxwh | 729d0ec69e | |
chenxwh | ef3dd8285b | |
chenxwh | 9746a1f60c | |
Chenxi | 4bd7b83b57 | |
Chenxi | 6e5382584c | |
chenxwh | 0da8ee4b7a | |
Chenxi | e3fc926ca4 | |
chenxwh | 0c6942fd2a | |
chenxwh | f649f9216b | |
JSTayco | 19593c5ce0 | |
JSTayco | 552d0bcd0d | |
Chenxi | 1e2f8391a7 | |
chenxwh | b8eca5a2d4 | |
chenxwh | 023d4b1c6c | |
chenxwh | 49a648fa54 | |
Jay | 66049a2526 |
|
@ -0,0 +1,17 @@
|
|||
# The .dockerignore file excludes files from the container build process.
|
||||
#
|
||||
# https://docs.docker.com/engine/reference/builder/#dockerignore-file
|
||||
|
||||
# Exclude Git files
|
||||
.git
|
||||
.github
|
||||
.gitignore
|
||||
|
||||
# Exclude Python cache files
|
||||
__pycache__
|
||||
.mypy_cache
|
||||
.pytest_cache
|
||||
.ruff_cache
|
||||
|
||||
# Exclude Python virtual environment
|
||||
/venv
|
|
@ -29,4 +29,5 @@ src/audiocraft
|
|||
!/demo/
|
||||
!/demo/*
|
||||
/demo/temp/*.txt
|
||||
!/demo/temp/84_121550_000074_000000.txt
|
||||
!/demo/temp/84_121550_000074_000000.txt
|
||||
.cog/tmp/*
|
|
@ -1,5 +1,6 @@
|
|||
# VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild
|
||||
[![Paper](https://img.shields.io/badge/arXiv-2301.12503-brightgreen.svg?style=flat-square)](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf) [![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://jasonppy.github.io/VoiceCraft_web/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IOjpglQyMTO2C3Y94LD9FY0Ocn-RJRg6?usp=sharing)
|
||||
[![Paper](https://img.shields.io/badge/arXiv-2403.16973-brightgreen.svg?style=flat-square)](https://arxiv.org/pdf/2403.16973.pdf) [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IOjpglQyMTO2C3Y94LD9FY0Ocn-RJRg6?usp=sharing) [![Replicate](https://replicate.com/cjwbw/voicecraft/badge)](https://replicate.com/cjwbw/voicecraft) [![YouTube demo](https://img.shields.io/youtube/views/eikybOi8iwU)](https://youtu.be/eikybOi8iwU) [![Demo page](https://img.shields.io/badge/Audio_Samples-blue?logo=Github&style=flat-square)](https://jasonppy.github.io/VoiceCraft_web/)
|
||||
|
||||
|
||||
### TL;DR
|
||||
VoiceCraft is a token infilling neural codec language model, that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data including audiobooks, internet videos, and podcasts.
|
||||
|
@ -18,6 +19,8 @@ When you are inside the docker image or you have installed all dependencies, Che
|
|||
If you want to do model development such as training/finetuning, I recommend following [envrionment setup](#environment-setup) and [training](#training).
|
||||
|
||||
## News
|
||||
:star: 04/22/2024: 330M/830M TTS Enhanced Models are up [here](https://huggingface.co/pyp1), load them through [`gradio_app.py`](./gradio_app.py) or [`inference_tts.ipynb`](./inference_tts.ipynb)! Replicate demo is up, major thanks to [@chenxwh](https://github.com/chenxwh)!
|
||||
|
||||
:star: 04/11/2024: VoiceCraft Gradio is now available on HuggingFace Spaces [here](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio)! Major thanks to [@zuev-stepan](https://github.com/zuev-stepan), [@Sewlell](https://github.com/Sewlell), [@pgsoar](https://github.com/pgosar) [@Ph0rk0z](https://github.com/Ph0rk0z).
|
||||
|
||||
:star: 04/05/2024: I finetuned giga330M with the TTS objective on gigaspeech and 1/5 of librilight. Weights are [here](https://huggingface.co/pyp1/VoiceCraft/tree/main). Make sure maximal prompt + generation length <= 16 seconds (due to our limited compute, we had to drop utterances longer than 16s in training data). Even stronger models forthcomming, stay tuned!
|
||||
|
@ -30,7 +33,7 @@ If you want to do model development such as training/finetuning, I recommend fol
|
|||
- [x] Inference demo for speech editing and TTS
|
||||
- [x] Training guidance
|
||||
- [x] RealEdit dataset and training manifest
|
||||
- [x] Model weights (giga330M.pth, giga830M.pth, and gigaHalfLibri330M_TTSEnhanced_max16s.pth)
|
||||
- [x] Model weights
|
||||
- [x] Better guidance on training/finetuning
|
||||
- [x] Colab notebooks
|
||||
- [x] HuggingFace Spaces demo
|
||||
|
@ -210,7 +213,7 @@ We thank Feiteng for his [VALL-E reproduction](https://github.com/lifeiteng/vall
|
|||
## Citation
|
||||
```
|
||||
@article{peng2024voicecraft,
|
||||
author = {Peng, Puyuan and Huang, Po-Yao and Li, Daniel and Mohamed, Abdelrahman and Harwath, David},
|
||||
author = {Peng, Puyuan and Huang, Po-Yao and Mohamed, Abdelrahman and Harwath, David},
|
||||
title = {VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild},
|
||||
journal = {arXiv},
|
||||
year = {2024},
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# Configuration for Cog ⚙️
|
||||
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
|
||||
|
||||
build:
|
||||
gpu: true
|
||||
system_packages:
|
||||
- libgl1-mesa-glx
|
||||
- libglib2.0-0
|
||||
- ffmpeg
|
||||
- espeak-ng
|
||||
python_version: "3.11"
|
||||
python_packages:
|
||||
- torch==2.1.0
|
||||
- torchaudio==2.1.0
|
||||
- xformers
|
||||
- phonemizer==3.2.1
|
||||
- whisperx==3.1.1
|
||||
- openai-whisper>=20231117
|
||||
run:
|
||||
- git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft
|
||||
- pip install "pydantic<2.0.0"
|
||||
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
|
||||
- mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth"
|
||||
predict: "predict.py:Predictor"
|
|
@ -1,4 +1,6 @@
|
|||
import os
|
||||
import re
|
||||
from num2words import num2words
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
|
@ -83,7 +85,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
|
|||
elif voicecraft_model_name == "830M":
|
||||
voicecraft_model_name = "giga830M"
|
||||
elif voicecraft_model_name == "330M_TTSEnhanced":
|
||||
voicecraft_model_name = "gigaHalfLibri330M_TTSEnhanced_max16s"
|
||||
voicecraft_model_name = "330M_TTSEnhanced"
|
||||
elif voicecraft_model_name == "830M_TTSEnhanced":
|
||||
voicecraft_model_name = "830M_TTSEnhanced"
|
||||
|
||||
|
@ -201,6 +203,15 @@ def get_output_audio(audio_tensors, codec_audio_sr):
|
|||
buffer.seek(0)
|
||||
return buffer.read()
|
||||
|
||||
def replace_numbers_with_words(sentence):
|
||||
sentence = re.sub(r'(\d+)', r' \1 ', sentence) # add spaces around numbers
|
||||
def replace_with_words(match):
|
||||
num = match.group(0)
|
||||
try:
|
||||
return num2words(num) # Convert numbers to words
|
||||
except:
|
||||
return num # In case num2words fails (unlikely with digits but just to be safe)
|
||||
return re.sub(r'\b\d+\b', replace_with_words, sentence) # Regular expression that matches numbers
|
||||
|
||||
def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
|
||||
stop_repetition, sample_batch_size, kvcache, silence_tokens,
|
||||
|
@ -213,6 +224,8 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
|
|||
raise gr.Error("Can't use smart transcript: whisper transcript not found")
|
||||
|
||||
seed_everything(seed)
|
||||
transcript = replace_numbers_with_words(transcript).replace(" ", " ").replace(" ", " ") # replace numbers with words, so that the phonemizer can do a better job
|
||||
|
||||
if mode == "Long TTS":
|
||||
if split_text == "Newline":
|
||||
sentences = transcript.split('\n')
|
||||
|
|
|
@ -4,3 +4,4 @@ openai-whisper>=20231117
|
|||
aeneas>=1.7.3.0
|
||||
whisperx>=3.1.1
|
||||
huggingface_hub==0.22.2
|
||||
num2words==0.5.13
|
||||
|
|
|
@ -71,7 +71,7 @@
|
|||
"# load model, encodec, and phn2num\n",
|
||||
"# # load model, tokenizer, and other necessary files\n",
|
||||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||
"voicecraft_name=\"830M_TTSEnhanced.pth\" # or giga330M.pth, gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n",
|
||||
"voicecraft_name=\"830M_TTSEnhanced.pth\" # or giga330M.pth, 330M_TTSEnhanced.pth, giga830M.pth\n",
|
||||
"\n",
|
||||
"# the new way of loading the model, with huggingface, recommended\n",
|
||||
"from models import voicecraft\n",
|
||||
|
|
|
@ -0,0 +1,452 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import argparse
|
||||
import importlib
|
||||
|
||||
from data.tokenizer import TextTokenizer, AudioTokenizer
|
||||
|
||||
# The following requirements are for VoiceCraft inside inference_tts_scale.py
|
||||
try:
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchmetrics
|
||||
import numpy
|
||||
import tqdm
|
||||
import phonemizer
|
||||
import audiocraft
|
||||
except ImportError:
|
||||
print(
|
||||
"Pre-reqs not found. Installing numpy, torch, and audio dependencies.")
|
||||
subprocess.run(
|
||||
["pip", "install", "numpy", "torch==2.0.1", "torchaudio",
|
||||
"torchmetrics", "tqdm", "phonemizer"])
|
||||
|
||||
subprocess.run(["pip", "install", "-e",
|
||||
"git+https://github.com/facebookresearch/audiocraft.git"
|
||||
"@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft"])
|
||||
|
||||
from inference_tts_scale import inference_one_sample
|
||||
from models import voicecraft
|
||||
|
||||
description = """
|
||||
VoiceCraft Inference Text-to-Speech Demo
|
||||
This script demonstrates how to use the VoiceCraft model for text-to-speech synthesis.
|
||||
|
||||
Pre-Requirements:
|
||||
- Python 3.9.16
|
||||
- Conda (https://docs.conda.io/en/latest/miniconda.html)
|
||||
- FFmpeg
|
||||
- eSpeak NG
|
||||
|
||||
Usage:
|
||||
1. Prepare an audio file and its corresponding transcript.
|
||||
2. Run the script with the required command-line arguments:
|
||||
python voicecraft_tts_demo.py --audio <path_to_audio_file> --transcript <path_to_transcript_file>
|
||||
3. The generated audio files will be saved in the `./demo/generated_tts` directory.
|
||||
|
||||
Notes:
|
||||
- The script will download the required models automatically if they are not found in the `./pretrained_models` directory.
|
||||
- You can adjust the hyperparameters using command-line arguments to fine-tune the text-to-speech synthesis.
|
||||
"""
|
||||
|
||||
|
||||
def is_tool(name):
|
||||
"""Check whether `name` is on PATH and marked as executable."""
|
||||
return shutil.which(name) is not None
|
||||
|
||||
|
||||
def run_command(command, error_message):
|
||||
if command[0] == "source":
|
||||
# Handle the 'source' command separately using os.system()
|
||||
status = os.system(" ".join(command))
|
||||
if status != 0:
|
||||
print(error_message)
|
||||
sys.exit(1)
|
||||
else:
|
||||
try:
|
||||
subprocess.run(command, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error: {e}")
|
||||
print(error_message)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def install_linux_dependencies():
|
||||
if is_tool("apt-get"):
|
||||
# Debian, Ubuntu, and derivatives
|
||||
run_command(["sudo", "apt-get", "update"],
|
||||
"Failed to update package lists.")
|
||||
run_command(["sudo", "apt-get", "install", "-y", "git-core", "ffmpeg",
|
||||
"espeak-ng"],
|
||||
"Failed to install Linux dependencies.")
|
||||
elif is_tool("pacman"):
|
||||
# Arch Linux and derivatives
|
||||
run_command(["sudo", "pacman", "-Syu", "--noconfirm", "git", "ffmpeg",
|
||||
"espeak-ng"],
|
||||
"Failed to install Linux dependencies.")
|
||||
elif is_tool("dnf"):
|
||||
# Fedora and derivatives
|
||||
run_command(
|
||||
["sudo", "dnf", "install", "-y", "git", "ffmpeg", "espeak-ng"],
|
||||
"Failed to install Linux dependencies.")
|
||||
elif is_tool("yum"):
|
||||
# CentOS and derivatives
|
||||
run_command(
|
||||
["sudo", "yum", "install", "-y", "git", "ffmpeg", "espeak-ng"],
|
||||
"Failed to install Linux dependencies.")
|
||||
else:
|
||||
print(
|
||||
"Error: Unsupported Linux distribution. Please install the dependencies manually.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def install_macos_dependencies():
|
||||
if is_tool("brew"):
|
||||
packages = ["git", "ffmpeg", "espeak", "anaconda"]
|
||||
missing_packages = [package for package in packages if
|
||||
not is_tool(package)]
|
||||
|
||||
if missing_packages:
|
||||
run_command(["brew", "install"] + missing_packages,
|
||||
"Failed to install missing macOS dependencies.")
|
||||
else:
|
||||
print("All required packages are already installed.")
|
||||
|
||||
# Add Anaconda bin directory to PATH
|
||||
anaconda_bin_path = "/opt/homebrew/anaconda3/bin"
|
||||
os.environ["PATH"] = f"{anaconda_bin_path}:{os.environ['PATH']}"
|
||||
|
||||
# Update the shell configuration file (e.g., .bash_profile or .zshrc)
|
||||
shell_config_file = os.path.expanduser(
|
||||
"~/.bash_profile") # or "~/.zshrc" for zsh
|
||||
with open(shell_config_file, "a") as file:
|
||||
file.write(f'\nexport PATH="{anaconda_bin_path}:$PATH"\n')
|
||||
|
||||
else:
|
||||
print(
|
||||
"Error: Homebrew not found. Please install Homebrew and try again.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def install_dependencies():
|
||||
if sys.platform == "win32":
|
||||
print(description)
|
||||
print("Please install the required dependencies manually on Windows.")
|
||||
sys.exit(1)
|
||||
elif sys.platform == "darwin":
|
||||
install_macos_dependencies()
|
||||
elif sys.platform.startswith("linux"):
|
||||
install_linux_dependencies()
|
||||
else:
|
||||
print(f"Unsupported platform: {sys.platform}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def install_conda_dependencies():
|
||||
conda_packages = [
|
||||
"montreal-forced-aligner=2.2.17",
|
||||
"openfst=1.8.2",
|
||||
"kaldi=5.5.1068"
|
||||
]
|
||||
|
||||
run_command(
|
||||
["conda", "install", "-y", "-c", "conda-forge", "--solver",
|
||||
"classic"] + conda_packages,
|
||||
"Failed to install Conda packages.")
|
||||
|
||||
|
||||
def create_conda_environment():
|
||||
run_command(["conda", "create", "-y", "-n", "voicecraft", "python=3.9.16",
|
||||
"--solver", "classic"],
|
||||
"Failed to create Conda environment.")
|
||||
|
||||
# Initialize Conda for the current shell session
|
||||
conda_init_command = 'eval "$(conda shell.bash hook)"'
|
||||
os.system(conda_init_command)
|
||||
|
||||
bashrc_path = os.path.expanduser("~/.bashrc")
|
||||
if os.path.exists(bashrc_path):
|
||||
run_command(["source", bashrc_path],
|
||||
"Failed to source .bashrc.")
|
||||
else:
|
||||
print("Warning: ~/.bashrc not found. Skipping sourcing.")
|
||||
|
||||
# Activate the Conda environment
|
||||
activate_command = f"conda activate voicecraft"
|
||||
os.system(activate_command)
|
||||
|
||||
# Install any required dependencies in Conda env
|
||||
install_conda_dependencies()
|
||||
|
||||
|
||||
def install_python_dependencies():
|
||||
pip_packages = [
|
||||
"torch==2.0.1",
|
||||
"tensorboard==2.16.2",
|
||||
"phonemizer==3.2.1",
|
||||
"torchaudio==2.0.2",
|
||||
"datasets==2.16.0",
|
||||
"torchmetrics==0.11.1"
|
||||
]
|
||||
|
||||
run_command(["pip", "install"] + pip_packages,
|
||||
"Failed to install Python packages.")
|
||||
|
||||
run_command(["pip", "install", "-e",
|
||||
"git+https://github.com/facebookresearch/audiocraft.git"
|
||||
"@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft"],
|
||||
"Failed to install audiocraft package.")
|
||||
|
||||
|
||||
def download_models(ckpt_fn, encodec_fn):
|
||||
if not os.path.exists(ckpt_fn):
|
||||
run_command(["wget",
|
||||
f"https://huggingface.co/pyp1/VoiceCraft/resolve/main/{os.path.basename(ckpt_fn)}?download=true"],
|
||||
f"Failed to download {ckpt_fn}.")
|
||||
run_command(
|
||||
["mv", f"{os.path.basename(ckpt_fn)}?download=true", ckpt_fn],
|
||||
f"Failed to move {ckpt_fn}.")
|
||||
|
||||
if not os.path.exists(encodec_fn):
|
||||
run_command(["wget",
|
||||
"https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th"],
|
||||
f"Failed to download {encodec_fn}.")
|
||||
run_command(["mv", "encodec_4cb2048_giga.th", encodec_fn],
|
||||
f"Failed to move {encodec_fn}.")
|
||||
|
||||
|
||||
def check_python_dependencies():
|
||||
dependencies = [
|
||||
"torch",
|
||||
"torchaudio",
|
||||
"data.tokenizer",
|
||||
"models.voicecraft",
|
||||
"inference_tts_scale",
|
||||
"audiocraft",
|
||||
"phonemizer",
|
||||
"tensorboard"
|
||||
]
|
||||
|
||||
missing_dependencies = []
|
||||
for dependency in dependencies:
|
||||
try:
|
||||
importlib.import_module(dependency)
|
||||
except ImportError:
|
||||
missing_dependencies.append(dependency)
|
||||
|
||||
if missing_dependencies:
|
||||
print("Missing Python dependencies:", missing_dependencies)
|
||||
install_python_dependencies()
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description=description,
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
parser.add_argument("-a", "--audio", required=True,
|
||||
help="Path to the input audio file used as a "
|
||||
"reference for the voice.")
|
||||
parser.add_argument("-t", "--transcript", required=True,
|
||||
help="Path to the text file containing the transcript "
|
||||
"to be synthesized.")
|
||||
parser.add_argument("--skip-install", "-s", action="store_true",
|
||||
help="Skip the installation of prerequisites.")
|
||||
parser.add_argument("--output_dir", default="./demo/generated_tts",
|
||||
help="Output directory where the generated audio "
|
||||
"files will be saved. Default: "
|
||||
"'./demo/generated_tts'")
|
||||
parser.add_argument("--cut-off-sec", type=float, default=3.0,
|
||||
help="Cut-off time in seconds for the audio prompt ("
|
||||
"hundredths of a second are acceptable). "
|
||||
"Default: 3.0")
|
||||
parser.add_argument("--left_margin", type=float, default=0.08,
|
||||
help="Left margin of the audio segment used for "
|
||||
"speech editing. This is not used for "
|
||||
"text-to-speech synthesis. Default: 0.08")
|
||||
parser.add_argument("--right_margin", type=float, default=0.08,
|
||||
help="Right margin of the audio segment used for "
|
||||
"speech editing. This is not used for "
|
||||
"text-to-speech synthesis. Default: 0.08")
|
||||
parser.add_argument("--codec_audio_sr", type=int, default=16000,
|
||||
help="Sample rate of the audio codec used for "
|
||||
"encoding and decoding. Default: 16000")
|
||||
parser.add_argument("--codec_sr", type=int, default=50,
|
||||
help="Sample rate of the codec used for encoding and "
|
||||
"decoding. Default: 50")
|
||||
parser.add_argument("--top_k", type=int, default=0,
|
||||
help="Top-k sampling parameter. It limits the number "
|
||||
"of highest probability tokens to consider "
|
||||
"during generation. A higher value (e.g., "
|
||||
"50) will result in more diverse but potentially "
|
||||
"less coherent speech, while a lower value ("
|
||||
"e.g., 1) will result in more conservative and "
|
||||
"repetitive speech. Setting it to 0 disables "
|
||||
"top-k sampling. Default: 0")
|
||||
parser.add_argument("--top_p", type=float, default=0.8,
|
||||
help="Top-p sampling parameter. It controls the "
|
||||
"diversity of the generated audio by truncating "
|
||||
"the least likely tokens whose cumulative "
|
||||
"probability exceeds 'p'. Lower values (e.g., "
|
||||
"0.5) will result in more conservative and "
|
||||
"repetitive speech, while higher values (e.g., "
|
||||
"0.9) will result in more diverse speech. "
|
||||
"Default: 0.8")
|
||||
parser.add_argument("--temperature", type=float, default=1.0,
|
||||
help="Sampling temperature. It controls the "
|
||||
"randomness of the generated speech. Higher "
|
||||
"values (e.g., 1.5) will result in more "
|
||||
"expressive and varied speech, while lower "
|
||||
"values (e.g., 0.5) will result in more "
|
||||
"monotonous and conservative speech. Default: 1.0")
|
||||
parser.add_argument("--kvcache", type=int, default=1,
|
||||
help="Key-value cache size used for caching "
|
||||
"intermediate results. A larger cache size may "
|
||||
"improve performance but consume more memory. "
|
||||
"Default: 1")
|
||||
parser.add_argument("--seed", type=int, default=1,
|
||||
help="Random seed for reproducibility. Use the same "
|
||||
"seed value to generate the same output for a "
|
||||
"given input. Default: 1")
|
||||
parser.add_argument("--stop_repetition", type=int, default=3,
|
||||
help="Stop repetition threshold. It controls the "
|
||||
"number of consecutive repetitions allowed in "
|
||||
"the generated speech. Lower values (e.g., "
|
||||
"1 or 2) will result in less repetitive speech "
|
||||
"but may also lead to abrupt stopping. Higher "
|
||||
"values (e.g., 4 or 5) will allow more "
|
||||
"repetitions. Default: 3")
|
||||
parser.add_argument("--sample_batch_size", type=int, default=4,
|
||||
help="Number of audio samples generated in parallel. "
|
||||
"Increasing this value may improve the quality "
|
||||
"of the generated speech by reducing long "
|
||||
"silences or unnaturally stretched words, "
|
||||
"but it will also increase memory usage. "
|
||||
"Default: 4")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
if not args.skip_install:
|
||||
install_dependencies()
|
||||
create_conda_environment()
|
||||
check_python_dependencies()
|
||||
|
||||
orig_audio = args.audio
|
||||
orig_transcript = args.transcript
|
||||
output_dir = args.output_dir
|
||||
|
||||
# Create the output directory if it doesn't exist
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Hyperparameters for inference
|
||||
left_margin = args.left_margin
|
||||
right_margin = args.right_margin
|
||||
codec_audio_sr = args.codec_audio_sr
|
||||
codec_sr = args.codec_sr
|
||||
top_k = args.top_k
|
||||
top_p = args.top_p
|
||||
temperature = args.temperature
|
||||
kvcache = args.kvcache
|
||||
silence_tokens = [1388, 1898, 131]
|
||||
seed = args.seed
|
||||
stop_repetition = args.stop_repetition
|
||||
sample_batch_size = args.sample_batch_size
|
||||
|
||||
# Set the device based on available hardware
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif sys.platform == "darwin" and torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
# Move audio and transcript to temp folder
|
||||
temp_folder = "./demo/temp"
|
||||
os.makedirs(temp_folder, exist_ok=True)
|
||||
subprocess.run(["cp", orig_audio, temp_folder])
|
||||
filename = os.path.splitext(os.path.basename(orig_audio))[0]
|
||||
with open(f"{temp_folder}/{filename}.txt", "w") as f:
|
||||
f.write(orig_transcript)
|
||||
|
||||
# Run MFA to get the alignment
|
||||
align_temp = f"{temp_folder}/mfa_alignments"
|
||||
os.makedirs(align_temp, exist_ok=True)
|
||||
subprocess.run(
|
||||
["mfa", "model", "download", "dictionary", "english_us_arpa"])
|
||||
subprocess.run(["mfa", "model", "download", "acoustic", "english_us_arpa"])
|
||||
subprocess.run(
|
||||
["mfa", "align", "-v", "--clean", "-j", "1", "--output_format", "csv",
|
||||
temp_folder, "english_us_arpa", "english_us_arpa", align_temp])
|
||||
|
||||
audio_fn = f"{temp_folder}/{os.path.basename(orig_audio)}"
|
||||
transcript_fn = f"{temp_folder}/{filename}.txt"
|
||||
align_fn = f"{align_temp}/{filename}.csv"
|
||||
|
||||
# Decide which part of the audio to use as prompt based on forced alignment
|
||||
cut_off_sec = args.cut_off_sec
|
||||
|
||||
info = torchaudio.info(audio_fn)
|
||||
audio_dur = info.num_frames / info.sample_rate
|
||||
assert cut_off_sec < audio_dur, f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
|
||||
prompt_end_frame = int(cut_off_sec * info.sample_rate)
|
||||
|
||||
# Load model, tokenizer, and other necessary files
|
||||
voicecraft_name = "giga830M.pth"
|
||||
ckpt_fn = f"./pretrained_models/{voicecraft_name}"
|
||||
encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
|
||||
|
||||
if not os.path.exists(ckpt_fn):
|
||||
subprocess.run(["wget",
|
||||
f"https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}?download=true"])
|
||||
subprocess.run(["mv", f"{voicecraft_name}?download=true",
|
||||
f"./pretrained_models/{voicecraft_name}"])
|
||||
|
||||
if not os.path.exists(encodec_fn):
|
||||
subprocess.run(["wget",
|
||||
"https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th"])
|
||||
subprocess.run(["mv", "encodec_4cb2048_giga.th",
|
||||
"./pretrained_models/encodec_4cb2048_giga.th"])
|
||||
|
||||
ckpt = torch.load(ckpt_fn, map_location="cpu")
|
||||
model = voicecraft.VoiceCraft(ckpt["config"])
|
||||
model.load_state_dict(ckpt["model"])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
phn2num = ckpt['phn2num']
|
||||
text_tokenizer = TextTokenizer(backend="espeak")
|
||||
audio_tokenizer = AudioTokenizer(
|
||||
signature=encodec_fn) # will also put the neural codec model on gpu
|
||||
|
||||
# Run the model to get the output
|
||||
decode_config = {
|
||||
'top_k': top_k, 'top_p': top_p, 'temperature': temperature,
|
||||
'stop_repetition': stop_repetition, 'kvcache': kvcache,
|
||||
"codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
|
||||
"silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size
|
||||
}
|
||||
concated_audio, gen_audio = inference_one_sample(
|
||||
model, ckpt["config"], phn2num, text_tokenizer, audio_tokenizer,
|
||||
audio_fn, transcript_fn, device, decode_config, prompt_end_frame
|
||||
)
|
||||
|
||||
# Save segments for comparison
|
||||
concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
|
||||
|
||||
# Save the audio
|
||||
seg_save_fn_gen = os.path.join(output_dir,
|
||||
f"{os.path.basename(orig_audio)[:-4]}_gen_seed{seed}.wav")
|
||||
seg_save_fn_concat = os.path.join(output_dir,
|
||||
f"{os.path.basename(orig_audio)[:-4]}_concat_seed{seed}.wav")
|
||||
|
||||
torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr)
|
||||
torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,389 @@
|
|||
# Prediction interface for Cog ⚙️
|
||||
# https://github.com/replicate/cog/blob/main/docs/python.md
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import getpass
|
||||
import shutil
|
||||
import subprocess
|
||||
import torch
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
from cog import BasePredictor, Input, Path, BaseModel
|
||||
|
||||
os.environ["USER"] = getpass.getuser()
|
||||
|
||||
from data.tokenizer import (
|
||||
AudioTokenizer,
|
||||
TextTokenizer,
|
||||
)
|
||||
from models import voicecraft
|
||||
from inference_tts_scale import inference_one_sample
|
||||
from edit_utils import get_span
|
||||
from inference_speech_editing_scale import (
|
||||
inference_one_sample as inference_one_sample_editing,
|
||||
)
|
||||
|
||||
|
||||
MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft-models.tar" # all the models are cached and uploaded to replicate.delivery for faster booting
|
||||
MODEL_CACHE = "model_cache"
|
||||
|
||||
|
||||
class ModelOutput(BaseModel):
|
||||
whisper_transcript_orig_audio: str
|
||||
generated_audio: Path
|
||||
|
||||
|
||||
class WhisperxAlignModel:
|
||||
def __init__(self):
|
||||
from whisperx import load_align_model
|
||||
|
||||
self.model, self.metadata = load_align_model(
|
||||
language_code="en", device="cuda:0"
|
||||
)
|
||||
|
||||
def align(self, segments, audio_path):
|
||||
from whisperx import align, load_audio
|
||||
|
||||
audio = load_audio(audio_path)
|
||||
return align(
|
||||
segments,
|
||||
self.model,
|
||||
self.metadata,
|
||||
audio,
|
||||
device="cuda:0",
|
||||
return_char_alignments=False,
|
||||
)["segments"]
|
||||
|
||||
|
||||
class WhisperxModel:
|
||||
def __init__(self, model_name, align_model: WhisperxAlignModel, device="cuda"):
|
||||
from whisperx import load_model
|
||||
|
||||
# the model weights are cached from Systran/faster-whisper-base.en etc
|
||||
self.model = load_model(
|
||||
model_name,
|
||||
device,
|
||||
asr_options={
|
||||
"suppress_numerals": True,
|
||||
"max_new_tokens": None,
|
||||
"clip_timestamps": None,
|
||||
"hallucination_silence_threshold": None,
|
||||
},
|
||||
)
|
||||
self.align_model = align_model
|
||||
|
||||
def transcribe(self, audio_path):
|
||||
segments = self.model.transcribe(audio_path, language="en", batch_size=8)[
|
||||
"segments"
|
||||
]
|
||||
return self.align_model.align(segments, audio_path)
|
||||
|
||||
|
||||
def download_weights(url, dest):
|
||||
start = time.time()
|
||||
print("downloading url: ", url)
|
||||
print("downloading to: ", dest)
|
||||
subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
|
||||
print("downloading took: ", time.time() - start)
|
||||
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
def setup(self):
|
||||
"""Load the model into memory to make running multiple predictions efficient"""
|
||||
self.device = "cuda"
|
||||
|
||||
if not os.path.exists(MODEL_CACHE):
|
||||
download_weights(MODEL_URL, MODEL_CACHE)
|
||||
|
||||
encodec_fn = f"{MODEL_CACHE}/encodec_4cb2048_giga.th"
|
||||
self.models, self.ckpt, self.phn2num = {}, {}, {}
|
||||
for voicecraft_name in [
|
||||
"giga830M.pth",
|
||||
"giga330M.pth",
|
||||
"gigaHalfLibri330M_TTSEnhanced_max16s.pth",
|
||||
]:
|
||||
ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}"
|
||||
|
||||
self.ckpt[voicecraft_name] = torch.load(ckpt_fn, map_location="cpu")
|
||||
self.models[voicecraft_name] = voicecraft.VoiceCraft(
|
||||
self.ckpt[voicecraft_name]["config"]
|
||||
)
|
||||
self.models[voicecraft_name].load_state_dict(
|
||||
self.ckpt[voicecraft_name]["model"]
|
||||
)
|
||||
self.models[voicecraft_name].to(self.device)
|
||||
self.models[voicecraft_name].eval()
|
||||
|
||||
self.phn2num[voicecraft_name] = self.ckpt[voicecraft_name]["phn2num"]
|
||||
|
||||
self.text_tokenizer = TextTokenizer(backend="espeak")
|
||||
self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)
|
||||
|
||||
align_model = WhisperxAlignModel()
|
||||
self.transcribe_models = {
|
||||
k: WhisperxModel(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}", align_model)
|
||||
for k in ["base.en", "small.en", "medium.en"]
|
||||
}
|
||||
|
||||
def predict(
|
||||
self,
|
||||
task: str = Input(
|
||||
description="Choose a task",
|
||||
choices=[
|
||||
"speech_editing-substitution",
|
||||
"speech_editing-insertion",
|
||||
"speech_editing-deletion",
|
||||
"zero-shot text-to-speech",
|
||||
],
|
||||
default="zero-shot text-to-speech",
|
||||
),
|
||||
voicecraft_model: str = Input(
|
||||
description="Choose a model",
|
||||
choices=["giga830M.pth", "giga330M.pth", "giga330M_TTSEnhanced.pth"],
|
||||
default="giga330M_TTSEnhanced.pth",
|
||||
),
|
||||
orig_audio: Path = Input(description="Original audio file"),
|
||||
orig_transcript: str = Input(
|
||||
description="Optionally provide the transcript of the input audio. Leave it blank to use the WhisperX model below to generate the transcript. Inaccurate transcription may lead to error TTS or speech editing",
|
||||
default="",
|
||||
),
|
||||
whisperx_model: str = Input(
|
||||
description="If orig_transcript is not provided above, choose a WhisperX model for generating the transcript. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to orig_transcript above",
|
||||
choices=[
|
||||
"base.en",
|
||||
"small.en",
|
||||
"medium.en",
|
||||
],
|
||||
default="base.en",
|
||||
),
|
||||
target_transcript: str = Input(
|
||||
description="Transcript of the target audio file",
|
||||
),
|
||||
cut_off_sec: float = Input(
|
||||
description="Only used for for zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech. 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec",
|
||||
default=3.01,
|
||||
),
|
||||
kvcache: int = Input(
|
||||
description="Set to 0 to use less VRAM, but with slower inference",
|
||||
choices=[0, 1],
|
||||
default=1,
|
||||
),
|
||||
left_margin: float = Input(
|
||||
description="Margin to the left of the editing segment",
|
||||
default=0.08,
|
||||
),
|
||||
right_margin: float = Input(
|
||||
description="Margin to the right of the editing segment",
|
||||
default=0.08,
|
||||
),
|
||||
temperature: float = Input(
|
||||
description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic. Do not recommend to change",
|
||||
default=1,
|
||||
),
|
||||
top_p: float = Input(
|
||||
description="Default value for TTS is 0.9, and 0.8 for speech editing",
|
||||
default=0.9,
|
||||
),
|
||||
stop_repetition: int = Input(
|
||||
default=3,
|
||||
description="Default value for TTS is 3, and -1 for speech editing. -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4",
|
||||
),
|
||||
sample_batch_size: int = Input(
|
||||
description="Default value for TTS is 4, and 1 for speech editing. The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one",
|
||||
default=4,
|
||||
),
|
||||
seed: int = Input(
|
||||
description="Random seed. Leave blank to randomize the seed", default=None
|
||||
),
|
||||
) -> ModelOutput:
|
||||
"""Run a single prediction on the model"""
|
||||
|
||||
if seed is None:
|
||||
seed = int.from_bytes(os.urandom(2), "big")
|
||||
print(f"Using seed: {seed}")
|
||||
|
||||
seed_everything(seed)
|
||||
|
||||
segments = self.transcribe_models[whisperx_model].transcribe(
|
||||
str(orig_audio)
|
||||
)
|
||||
|
||||
state = get_transcribe_state(segments)
|
||||
|
||||
whisper_transcript = state["transcript"].strip()
|
||||
|
||||
if len(orig_transcript.strip()) == 0:
|
||||
orig_transcript = whisper_transcript
|
||||
|
||||
print(f"The transcript from the Whisper model: {whisper_transcript}")
|
||||
|
||||
temp_folder = "exp_dir"
|
||||
if os.path.exists(temp_folder):
|
||||
shutil.rmtree(temp_folder)
|
||||
|
||||
os.makedirs(temp_folder)
|
||||
|
||||
filename = "orig_audio"
|
||||
audio_fn = str(orig_audio)
|
||||
|
||||
info = torchaudio.info(audio_fn)
|
||||
audio_dur = info.num_frames / info.sample_rate
|
||||
|
||||
# hyperparameters for inference
|
||||
codec_audio_sr = 16000
|
||||
codec_sr = 50
|
||||
top_k = 0
|
||||
silence_tokens = [1388, 1898, 131]
|
||||
|
||||
if voicecraft_model == "giga330M_TTSEnhanced.pth":
|
||||
voicecraft_model = "gigaHalfLibri330M_TTSEnhanced_max16s.pth"
|
||||
|
||||
if task == "zero-shot text-to-speech":
|
||||
assert (
|
||||
cut_off_sec < audio_dur
|
||||
), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
|
||||
prompt_end_frame = int(cut_off_sec * info.sample_rate)
|
||||
|
||||
idx = find_closest_cut_off_word(state["word_bounds"], cut_off_sec)
|
||||
orig_transcript_until_cutoff_time = " ".join(
|
||||
[word_bound["word"] for word_bound in state["word_bounds"][: idx + 1]]
|
||||
)
|
||||
else:
|
||||
edit_type = task.split("-")[-1]
|
||||
orig_span, new_span = get_span(
|
||||
orig_transcript, target_transcript, edit_type
|
||||
)
|
||||
if orig_span[0] > orig_span[1]:
|
||||
RuntimeError(f"example {audio_fn} failed")
|
||||
if orig_span[0] == orig_span[1]:
|
||||
orig_span_save = [orig_span[0]]
|
||||
else:
|
||||
orig_span_save = orig_span
|
||||
if new_span[0] == new_span[1]:
|
||||
new_span_save = [new_span[0]]
|
||||
else:
|
||||
new_span_save = new_span
|
||||
orig_span_save = ",".join([str(item) for item in orig_span_save])
|
||||
new_span_save = ",".join([str(item) for item in new_span_save])
|
||||
|
||||
start, end = get_mask_interval_from_word_bounds(
|
||||
state["word_bounds"], orig_span_save, edit_type
|
||||
)
|
||||
|
||||
# span in codec frames
|
||||
morphed_span = (
|
||||
max(start - left_margin, 1 / codec_sr),
|
||||
min(end + right_margin, audio_dur),
|
||||
) # in seconds
|
||||
mask_interval = [
|
||||
[round(morphed_span[0] * codec_sr), round(morphed_span[1] * codec_sr)]
|
||||
]
|
||||
mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
|
||||
|
||||
decode_config = {
|
||||
"top_k": top_k,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"stop_repetition": stop_repetition,
|
||||
"kvcache": kvcache,
|
||||
"codec_audio_sr": codec_audio_sr,
|
||||
"codec_sr": codec_sr,
|
||||
"silence_tokens": silence_tokens,
|
||||
}
|
||||
|
||||
if task == "zero-shot text-to-speech":
|
||||
decode_config["sample_batch_size"] = sample_batch_size
|
||||
_, gen_audio = inference_one_sample(
|
||||
self.models[voicecraft_model],
|
||||
self.ckpt[voicecraft_model]["config"],
|
||||
self.phn2num[voicecraft_model],
|
||||
self.text_tokenizer,
|
||||
self.audio_tokenizer,
|
||||
audio_fn,
|
||||
orig_transcript_until_cutoff_time.strip()
|
||||
+ " "
|
||||
+ target_transcript.strip(),
|
||||
self.device,
|
||||
decode_config,
|
||||
prompt_end_frame,
|
||||
)
|
||||
else:
|
||||
_, gen_audio = inference_one_sample_editing(
|
||||
self.models[voicecraft_model],
|
||||
self.ckpt[voicecraft_model]["config"],
|
||||
self.phn2num[voicecraft_model],
|
||||
self.text_tokenizer,
|
||||
self.audio_tokenizer,
|
||||
audio_fn,
|
||||
target_transcript,
|
||||
mask_interval,
|
||||
self.device,
|
||||
decode_config,
|
||||
)
|
||||
|
||||
# save segments for comparison
|
||||
gen_audio = gen_audio[0].cpu()
|
||||
|
||||
out = "/tmp/out.wav"
|
||||
torchaudio.save(out, gen_audio, codec_audio_sr)
|
||||
return ModelOutput(
|
||||
generated_audio=Path(out), whisper_transcript_orig_audio=whisper_transcript
|
||||
)
|
||||
|
||||
|
||||
def seed_everything(seed):
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def get_transcribe_state(segments):
|
||||
words_info = [word_info for segment in segments for word_info in segment["words"]]
|
||||
return {
|
||||
"transcript": " ".join([segment["text"].strip() for segment in segments]),
|
||||
"word_bounds": [
|
||||
{"word": word["word"], "start": word["start"], "end": word["end"]}
|
||||
for word in words_info
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def find_closest_cut_off_word(word_bounds, cut_off_sec):
|
||||
min_distance = float("inf")
|
||||
|
||||
for i, word_bound in enumerate(word_bounds):
|
||||
distance = abs(word_bound["start"] - cut_off_sec)
|
||||
|
||||
if distance < min_distance:
|
||||
min_distance = distance
|
||||
|
||||
if word_bound["end"] > cut_off_sec:
|
||||
break
|
||||
|
||||
return i
|
||||
|
||||
|
||||
def get_mask_interval_from_word_bounds(word_bounds, word_span_ind, editType):
|
||||
tmp = word_span_ind.split(",")
|
||||
s, e = int(tmp[0]), int(tmp[-1])
|
||||
start = None
|
||||
for j, item in enumerate(word_bounds):
|
||||
if j == s:
|
||||
if editType == "insertion":
|
||||
start = float(item["end"])
|
||||
else:
|
||||
start = float(item["start"])
|
||||
if j == e:
|
||||
if editType == "insertion":
|
||||
end = float(item["start"])
|
||||
else:
|
||||
end = float(item["end"])
|
||||
assert start is not None
|
||||
break
|
||||
return (start, end)
|
Loading…
Reference in New Issue