This commit is contained in:
Jay 2024-04-17 16:05:02 +08:00 committed by GitHub
commit 7e39208666
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 452 additions and 0 deletions

452
inference_tts.py Normal file
View File

@ -0,0 +1,452 @@
#!/usr/bin/env python3
import os
import shutil
import subprocess
import sys
import argparse
import importlib
from data.tokenizer import TextTokenizer, AudioTokenizer
# The following requirements are for VoiceCraft inside inference_tts_scale.py
try:
import torch
import torchaudio
import torchmetrics
import numpy
import tqdm
import phonemizer
import audiocraft
except ImportError:
print(
"Pre-reqs not found. Installing numpy, torch, and audio dependencies.")
subprocess.run(
["pip", "install", "numpy", "torch==2.0.1", "torchaudio",
"torchmetrics", "tqdm", "phonemizer"])
subprocess.run(["pip", "install", "-e",
"git+https://github.com/facebookresearch/audiocraft.git"
"@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft"])
from inference_tts_scale import inference_one_sample
from models import voicecraft
description = """
VoiceCraft Inference Text-to-Speech Demo
This script demonstrates how to use the VoiceCraft model for text-to-speech synthesis.
Pre-Requirements:
- Python 3.9.16
- Conda (https://docs.conda.io/en/latest/miniconda.html)
- FFmpeg
- eSpeak NG
Usage:
1. Prepare an audio file and its corresponding transcript.
2. Run the script with the required command-line arguments:
python voicecraft_tts_demo.py --audio <path_to_audio_file> --transcript <path_to_transcript_file>
3. The generated audio files will be saved in the `./demo/generated_tts` directory.
Notes:
- The script will download the required models automatically if they are not found in the `./pretrained_models` directory.
- You can adjust the hyperparameters using command-line arguments to fine-tune the text-to-speech synthesis.
"""
def is_tool(name):
"""Check whether `name` is on PATH and marked as executable."""
return shutil.which(name) is not None
def run_command(command, error_message):
if command[0] == "source":
# Handle the 'source' command separately using os.system()
status = os.system(" ".join(command))
if status != 0:
print(error_message)
sys.exit(1)
else:
try:
subprocess.run(command, check=True)
except subprocess.CalledProcessError as e:
print(f"Error: {e}")
print(error_message)
sys.exit(1)
def install_linux_dependencies():
if is_tool("apt-get"):
# Debian, Ubuntu, and derivatives
run_command(["sudo", "apt-get", "update"],
"Failed to update package lists.")
run_command(["sudo", "apt-get", "install", "-y", "git-core", "ffmpeg",
"espeak-ng"],
"Failed to install Linux dependencies.")
elif is_tool("pacman"):
# Arch Linux and derivatives
run_command(["sudo", "pacman", "-Syu", "--noconfirm", "git", "ffmpeg",
"espeak-ng"],
"Failed to install Linux dependencies.")
elif is_tool("dnf"):
# Fedora and derivatives
run_command(
["sudo", "dnf", "install", "-y", "git", "ffmpeg", "espeak-ng"],
"Failed to install Linux dependencies.")
elif is_tool("yum"):
# CentOS and derivatives
run_command(
["sudo", "yum", "install", "-y", "git", "ffmpeg", "espeak-ng"],
"Failed to install Linux dependencies.")
else:
print(
"Error: Unsupported Linux distribution. Please install the dependencies manually.")
sys.exit(1)
def install_macos_dependencies():
if is_tool("brew"):
packages = ["git", "ffmpeg", "espeak", "anaconda"]
missing_packages = [package for package in packages if
not is_tool(package)]
if missing_packages:
run_command(["brew", "install"] + missing_packages,
"Failed to install missing macOS dependencies.")
else:
print("All required packages are already installed.")
# Add Anaconda bin directory to PATH
anaconda_bin_path = "/opt/homebrew/anaconda3/bin"
os.environ["PATH"] = f"{anaconda_bin_path}:{os.environ['PATH']}"
# Update the shell configuration file (e.g., .bash_profile or .zshrc)
shell_config_file = os.path.expanduser(
"~/.bash_profile") # or "~/.zshrc" for zsh
with open(shell_config_file, "a") as file:
file.write(f'\nexport PATH="{anaconda_bin_path}:$PATH"\n')
else:
print(
"Error: Homebrew not found. Please install Homebrew and try again.")
sys.exit(1)
def install_dependencies():
if sys.platform == "win32":
print(description)
print("Please install the required dependencies manually on Windows.")
sys.exit(1)
elif sys.platform == "darwin":
install_macos_dependencies()
elif sys.platform.startswith("linux"):
install_linux_dependencies()
else:
print(f"Unsupported platform: {sys.platform}")
sys.exit(1)
def install_conda_dependencies():
conda_packages = [
"montreal-forced-aligner=2.2.17",
"openfst=1.8.2",
"kaldi=5.5.1068"
]
run_command(
["conda", "install", "-y", "-c", "conda-forge", "--solver",
"classic"] + conda_packages,
"Failed to install Conda packages.")
def create_conda_environment():
run_command(["conda", "create", "-y", "-n", "voicecraft", "python=3.9.16",
"--solver", "classic"],
"Failed to create Conda environment.")
# Initialize Conda for the current shell session
conda_init_command = 'eval "$(conda shell.bash hook)"'
os.system(conda_init_command)
bashrc_path = os.path.expanduser("~/.bashrc")
if os.path.exists(bashrc_path):
run_command(["source", bashrc_path],
"Failed to source .bashrc.")
else:
print("Warning: ~/.bashrc not found. Skipping sourcing.")
# Activate the Conda environment
activate_command = f"conda activate voicecraft"
os.system(activate_command)
# Install any required dependencies in Conda env
install_conda_dependencies()
def install_python_dependencies():
pip_packages = [
"torch==2.0.1",
"tensorboard==2.16.2",
"phonemizer==3.2.1",
"torchaudio==2.0.2",
"datasets==2.16.0",
"torchmetrics==0.11.1"
]
run_command(["pip", "install"] + pip_packages,
"Failed to install Python packages.")
run_command(["pip", "install", "-e",
"git+https://github.com/facebookresearch/audiocraft.git"
"@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft"],
"Failed to install audiocraft package.")
def download_models(ckpt_fn, encodec_fn):
if not os.path.exists(ckpt_fn):
run_command(["wget",
f"https://huggingface.co/pyp1/VoiceCraft/resolve/main/{os.path.basename(ckpt_fn)}?download=true"],
f"Failed to download {ckpt_fn}.")
run_command(
["mv", f"{os.path.basename(ckpt_fn)}?download=true", ckpt_fn],
f"Failed to move {ckpt_fn}.")
if not os.path.exists(encodec_fn):
run_command(["wget",
"https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th"],
f"Failed to download {encodec_fn}.")
run_command(["mv", "encodec_4cb2048_giga.th", encodec_fn],
f"Failed to move {encodec_fn}.")
def check_python_dependencies():
dependencies = [
"torch",
"torchaudio",
"data.tokenizer",
"models.voicecraft",
"inference_tts_scale",
"audiocraft",
"phonemizer",
"tensorboard"
]
missing_dependencies = []
for dependency in dependencies:
try:
importlib.import_module(dependency)
except ImportError:
missing_dependencies.append(dependency)
if missing_dependencies:
print("Missing Python dependencies:", missing_dependencies)
install_python_dependencies()
def parse_arguments():
parser = argparse.ArgumentParser(description=description,
formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("-a", "--audio", required=True,
help="Path to the input audio file used as a "
"reference for the voice.")
parser.add_argument("-t", "--transcript", required=True,
help="Path to the text file containing the transcript "
"to be synthesized.")
parser.add_argument("--skip-install", "-s", action="store_true",
help="Skip the installation of prerequisites.")
parser.add_argument("--output_dir", default="./demo/generated_tts",
help="Output directory where the generated audio "
"files will be saved. Default: "
"'./demo/generated_tts'")
parser.add_argument("--cut-off-sec", type=float, default=3.0,
help="Cut-off time in seconds for the audio prompt ("
"hundredths of a second are acceptable). "
"Default: 3.0")
parser.add_argument("--left_margin", type=float, default=0.08,
help="Left margin of the audio segment used for "
"speech editing. This is not used for "
"text-to-speech synthesis. Default: 0.08")
parser.add_argument("--right_margin", type=float, default=0.08,
help="Right margin of the audio segment used for "
"speech editing. This is not used for "
"text-to-speech synthesis. Default: 0.08")
parser.add_argument("--codec_audio_sr", type=int, default=16000,
help="Sample rate of the audio codec used for "
"encoding and decoding. Default: 16000")
parser.add_argument("--codec_sr", type=int, default=50,
help="Sample rate of the codec used for encoding and "
"decoding. Default: 50")
parser.add_argument("--top_k", type=int, default=0,
help="Top-k sampling parameter. It limits the number "
"of highest probability tokens to consider "
"during generation. A higher value (e.g., "
"50) will result in more diverse but potentially "
"less coherent speech, while a lower value ("
"e.g., 1) will result in more conservative and "
"repetitive speech. Setting it to 0 disables "
"top-k sampling. Default: 0")
parser.add_argument("--top_p", type=float, default=0.8,
help="Top-p sampling parameter. It controls the "
"diversity of the generated audio by truncating "
"the least likely tokens whose cumulative "
"probability exceeds 'p'. Lower values (e.g., "
"0.5) will result in more conservative and "
"repetitive speech, while higher values (e.g., "
"0.9) will result in more diverse speech. "
"Default: 0.8")
parser.add_argument("--temperature", type=float, default=1.0,
help="Sampling temperature. It controls the "
"randomness of the generated speech. Higher "
"values (e.g., 1.5) will result in more "
"expressive and varied speech, while lower "
"values (e.g., 0.5) will result in more "
"monotonous and conservative speech. Default: 1.0")
parser.add_argument("--kvcache", type=int, default=1,
help="Key-value cache size used for caching "
"intermediate results. A larger cache size may "
"improve performance but consume more memory. "
"Default: 1")
parser.add_argument("--seed", type=int, default=1,
help="Random seed for reproducibility. Use the same "
"seed value to generate the same output for a "
"given input. Default: 1")
parser.add_argument("--stop_repetition", type=int, default=3,
help="Stop repetition threshold. It controls the "
"number of consecutive repetitions allowed in "
"the generated speech. Lower values (e.g., "
"1 or 2) will result in less repetitive speech "
"but may also lead to abrupt stopping. Higher "
"values (e.g., 4 or 5) will allow more "
"repetitions. Default: 3")
parser.add_argument("--sample_batch_size", type=int, default=4,
help="Number of audio samples generated in parallel. "
"Increasing this value may improve the quality "
"of the generated speech by reducing long "
"silences or unnaturally stretched words, "
"but it will also increase memory usage. "
"Default: 4")
return parser.parse_args()
def main():
args = parse_arguments()
if not args.skip_install:
install_dependencies()
create_conda_environment()
check_python_dependencies()
orig_audio = args.audio
orig_transcript = args.transcript
output_dir = args.output_dir
# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Hyperparameters for inference
left_margin = args.left_margin
right_margin = args.right_margin
codec_audio_sr = args.codec_audio_sr
codec_sr = args.codec_sr
top_k = args.top_k
top_p = args.top_p
temperature = args.temperature
kvcache = args.kvcache
silence_tokens = [1388, 1898, 131]
seed = args.seed
stop_repetition = args.stop_repetition
sample_batch_size = args.sample_batch_size
# Set the device based on available hardware
if torch.cuda.is_available():
device = "cuda"
elif sys.platform == "darwin" and torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
# Move audio and transcript to temp folder
temp_folder = "./demo/temp"
os.makedirs(temp_folder, exist_ok=True)
subprocess.run(["cp", orig_audio, temp_folder])
filename = os.path.splitext(os.path.basename(orig_audio))[0]
with open(f"{temp_folder}/{filename}.txt", "w") as f:
f.write(orig_transcript)
# Run MFA to get the alignment
align_temp = f"{temp_folder}/mfa_alignments"
os.makedirs(align_temp, exist_ok=True)
subprocess.run(
["mfa", "model", "download", "dictionary", "english_us_arpa"])
subprocess.run(["mfa", "model", "download", "acoustic", "english_us_arpa"])
subprocess.run(
["mfa", "align", "-v", "--clean", "-j", "1", "--output_format", "csv",
temp_folder, "english_us_arpa", "english_us_arpa", align_temp])
audio_fn = f"{temp_folder}/{os.path.basename(orig_audio)}"
transcript_fn = f"{temp_folder}/{filename}.txt"
align_fn = f"{align_temp}/{filename}.csv"
# Decide which part of the audio to use as prompt based on forced alignment
cut_off_sec = args.cut_off_sec
info = torchaudio.info(audio_fn)
audio_dur = info.num_frames / info.sample_rate
assert cut_off_sec < audio_dur, f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
prompt_end_frame = int(cut_off_sec * info.sample_rate)
# Load model, tokenizer, and other necessary files
voicecraft_name = "giga830M.pth"
ckpt_fn = f"./pretrained_models/{voicecraft_name}"
encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
if not os.path.exists(ckpt_fn):
subprocess.run(["wget",
f"https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}?download=true"])
subprocess.run(["mv", f"{voicecraft_name}?download=true",
f"./pretrained_models/{voicecraft_name}"])
if not os.path.exists(encodec_fn):
subprocess.run(["wget",
"https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th"])
subprocess.run(["mv", "encodec_4cb2048_giga.th",
"./pretrained_models/encodec_4cb2048_giga.th"])
ckpt = torch.load(ckpt_fn, map_location="cpu")
model = voicecraft.VoiceCraft(ckpt["config"])
model.load_state_dict(ckpt["model"])
model.to(device)
model.eval()
phn2num = ckpt['phn2num']
text_tokenizer = TextTokenizer(backend="espeak")
audio_tokenizer = AudioTokenizer(
signature=encodec_fn) # will also put the neural codec model on gpu
# Run the model to get the output
decode_config = {
'top_k': top_k, 'top_p': top_p, 'temperature': temperature,
'stop_repetition': stop_repetition, 'kvcache': kvcache,
"codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
"silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size
}
concated_audio, gen_audio = inference_one_sample(
model, ckpt["config"], phn2num, text_tokenizer, audio_tokenizer,
audio_fn, transcript_fn, device, decode_config, prompt_end_frame
)
# Save segments for comparison
concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
# Save the audio
seg_save_fn_gen = os.path.join(output_dir,
f"{os.path.basename(orig_audio)[:-4]}_gen_seed{seed}.wav")
seg_save_fn_concat = os.path.join(output_dir,
f"{os.path.basename(orig_audio)[:-4]}_concat_seed{seed}.wav")
torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr)
torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr)
if __name__ == "__main__":
main()