Merge 19593c5ce0
into ddfef83331
This commit is contained in:
commit
1ef9ce5632
|
@ -0,0 +1,452 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import argparse
|
||||
import importlib
|
||||
|
||||
from data.tokenizer import TextTokenizer, AudioTokenizer
|
||||
|
||||
# The following requirements are for VoiceCraft inside inference_tts_scale.py
|
||||
try:
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchmetrics
|
||||
import numpy
|
||||
import tqdm
|
||||
import phonemizer
|
||||
import audiocraft
|
||||
except ImportError:
|
||||
print(
|
||||
"Pre-reqs not found. Installing numpy, torch, and audio dependencies.")
|
||||
subprocess.run(
|
||||
["pip", "install", "numpy", "torch==2.0.1", "torchaudio",
|
||||
"torchmetrics", "tqdm", "phonemizer"])
|
||||
|
||||
subprocess.run(["pip", "install", "-e",
|
||||
"git+https://github.com/facebookresearch/audiocraft.git"
|
||||
"@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft"])
|
||||
|
||||
from inference_tts_scale import inference_one_sample
|
||||
from models import voicecraft
|
||||
|
||||
description = """
|
||||
VoiceCraft Inference Text-to-Speech Demo
|
||||
This script demonstrates how to use the VoiceCraft model for text-to-speech synthesis.
|
||||
|
||||
Pre-Requirements:
|
||||
- Python 3.9.16
|
||||
- Conda (https://docs.conda.io/en/latest/miniconda.html)
|
||||
- FFmpeg
|
||||
- eSpeak NG
|
||||
|
||||
Usage:
|
||||
1. Prepare an audio file and its corresponding transcript.
|
||||
2. Run the script with the required command-line arguments:
|
||||
python voicecraft_tts_demo.py --audio <path_to_audio_file> --transcript <path_to_transcript_file>
|
||||
3. The generated audio files will be saved in the `./demo/generated_tts` directory.
|
||||
|
||||
Notes:
|
||||
- The script will download the required models automatically if they are not found in the `./pretrained_models` directory.
|
||||
- You can adjust the hyperparameters using command-line arguments to fine-tune the text-to-speech synthesis.
|
||||
"""
|
||||
|
||||
|
||||
def is_tool(name):
|
||||
"""Check whether `name` is on PATH and marked as executable."""
|
||||
return shutil.which(name) is not None
|
||||
|
||||
|
||||
def run_command(command, error_message):
|
||||
if command[0] == "source":
|
||||
# Handle the 'source' command separately using os.system()
|
||||
status = os.system(" ".join(command))
|
||||
if status != 0:
|
||||
print(error_message)
|
||||
sys.exit(1)
|
||||
else:
|
||||
try:
|
||||
subprocess.run(command, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error: {e}")
|
||||
print(error_message)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def install_linux_dependencies():
|
||||
if is_tool("apt-get"):
|
||||
# Debian, Ubuntu, and derivatives
|
||||
run_command(["sudo", "apt-get", "update"],
|
||||
"Failed to update package lists.")
|
||||
run_command(["sudo", "apt-get", "install", "-y", "git-core", "ffmpeg",
|
||||
"espeak-ng"],
|
||||
"Failed to install Linux dependencies.")
|
||||
elif is_tool("pacman"):
|
||||
# Arch Linux and derivatives
|
||||
run_command(["sudo", "pacman", "-Syu", "--noconfirm", "git", "ffmpeg",
|
||||
"espeak-ng"],
|
||||
"Failed to install Linux dependencies.")
|
||||
elif is_tool("dnf"):
|
||||
# Fedora and derivatives
|
||||
run_command(
|
||||
["sudo", "dnf", "install", "-y", "git", "ffmpeg", "espeak-ng"],
|
||||
"Failed to install Linux dependencies.")
|
||||
elif is_tool("yum"):
|
||||
# CentOS and derivatives
|
||||
run_command(
|
||||
["sudo", "yum", "install", "-y", "git", "ffmpeg", "espeak-ng"],
|
||||
"Failed to install Linux dependencies.")
|
||||
else:
|
||||
print(
|
||||
"Error: Unsupported Linux distribution. Please install the dependencies manually.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def install_macos_dependencies():
|
||||
if is_tool("brew"):
|
||||
packages = ["git", "ffmpeg", "espeak", "anaconda"]
|
||||
missing_packages = [package for package in packages if
|
||||
not is_tool(package)]
|
||||
|
||||
if missing_packages:
|
||||
run_command(["brew", "install"] + missing_packages,
|
||||
"Failed to install missing macOS dependencies.")
|
||||
else:
|
||||
print("All required packages are already installed.")
|
||||
|
||||
# Add Anaconda bin directory to PATH
|
||||
anaconda_bin_path = "/opt/homebrew/anaconda3/bin"
|
||||
os.environ["PATH"] = f"{anaconda_bin_path}:{os.environ['PATH']}"
|
||||
|
||||
# Update the shell configuration file (e.g., .bash_profile or .zshrc)
|
||||
shell_config_file = os.path.expanduser(
|
||||
"~/.bash_profile") # or "~/.zshrc" for zsh
|
||||
with open(shell_config_file, "a") as file:
|
||||
file.write(f'\nexport PATH="{anaconda_bin_path}:$PATH"\n')
|
||||
|
||||
else:
|
||||
print(
|
||||
"Error: Homebrew not found. Please install Homebrew and try again.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def install_dependencies():
|
||||
if sys.platform == "win32":
|
||||
print(description)
|
||||
print("Please install the required dependencies manually on Windows.")
|
||||
sys.exit(1)
|
||||
elif sys.platform == "darwin":
|
||||
install_macos_dependencies()
|
||||
elif sys.platform.startswith("linux"):
|
||||
install_linux_dependencies()
|
||||
else:
|
||||
print(f"Unsupported platform: {sys.platform}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def install_conda_dependencies():
|
||||
conda_packages = [
|
||||
"montreal-forced-aligner=2.2.17",
|
||||
"openfst=1.8.2",
|
||||
"kaldi=5.5.1068"
|
||||
]
|
||||
|
||||
run_command(
|
||||
["conda", "install", "-y", "-c", "conda-forge", "--solver",
|
||||
"classic"] + conda_packages,
|
||||
"Failed to install Conda packages.")
|
||||
|
||||
|
||||
def create_conda_environment():
|
||||
run_command(["conda", "create", "-y", "-n", "voicecraft", "python=3.9.16",
|
||||
"--solver", "classic"],
|
||||
"Failed to create Conda environment.")
|
||||
|
||||
# Initialize Conda for the current shell session
|
||||
conda_init_command = 'eval "$(conda shell.bash hook)"'
|
||||
os.system(conda_init_command)
|
||||
|
||||
bashrc_path = os.path.expanduser("~/.bashrc")
|
||||
if os.path.exists(bashrc_path):
|
||||
run_command(["source", bashrc_path],
|
||||
"Failed to source .bashrc.")
|
||||
else:
|
||||
print("Warning: ~/.bashrc not found. Skipping sourcing.")
|
||||
|
||||
# Activate the Conda environment
|
||||
activate_command = f"conda activate voicecraft"
|
||||
os.system(activate_command)
|
||||
|
||||
# Install any required dependencies in Conda env
|
||||
install_conda_dependencies()
|
||||
|
||||
|
||||
def install_python_dependencies():
|
||||
pip_packages = [
|
||||
"torch==2.0.1",
|
||||
"tensorboard==2.16.2",
|
||||
"phonemizer==3.2.1",
|
||||
"torchaudio==2.0.2",
|
||||
"datasets==2.16.0",
|
||||
"torchmetrics==0.11.1"
|
||||
]
|
||||
|
||||
run_command(["pip", "install"] + pip_packages,
|
||||
"Failed to install Python packages.")
|
||||
|
||||
run_command(["pip", "install", "-e",
|
||||
"git+https://github.com/facebookresearch/audiocraft.git"
|
||||
"@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft"],
|
||||
"Failed to install audiocraft package.")
|
||||
|
||||
|
||||
def download_models(ckpt_fn, encodec_fn):
|
||||
if not os.path.exists(ckpt_fn):
|
||||
run_command(["wget",
|
||||
f"https://huggingface.co/pyp1/VoiceCraft/resolve/main/{os.path.basename(ckpt_fn)}?download=true"],
|
||||
f"Failed to download {ckpt_fn}.")
|
||||
run_command(
|
||||
["mv", f"{os.path.basename(ckpt_fn)}?download=true", ckpt_fn],
|
||||
f"Failed to move {ckpt_fn}.")
|
||||
|
||||
if not os.path.exists(encodec_fn):
|
||||
run_command(["wget",
|
||||
"https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th"],
|
||||
f"Failed to download {encodec_fn}.")
|
||||
run_command(["mv", "encodec_4cb2048_giga.th", encodec_fn],
|
||||
f"Failed to move {encodec_fn}.")
|
||||
|
||||
|
||||
def check_python_dependencies():
|
||||
dependencies = [
|
||||
"torch",
|
||||
"torchaudio",
|
||||
"data.tokenizer",
|
||||
"models.voicecraft",
|
||||
"inference_tts_scale",
|
||||
"audiocraft",
|
||||
"phonemizer",
|
||||
"tensorboard"
|
||||
]
|
||||
|
||||
missing_dependencies = []
|
||||
for dependency in dependencies:
|
||||
try:
|
||||
importlib.import_module(dependency)
|
||||
except ImportError:
|
||||
missing_dependencies.append(dependency)
|
||||
|
||||
if missing_dependencies:
|
||||
print("Missing Python dependencies:", missing_dependencies)
|
||||
install_python_dependencies()
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description=description,
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
parser.add_argument("-a", "--audio", required=True,
|
||||
help="Path to the input audio file used as a "
|
||||
"reference for the voice.")
|
||||
parser.add_argument("-t", "--transcript", required=True,
|
||||
help="Path to the text file containing the transcript "
|
||||
"to be synthesized.")
|
||||
parser.add_argument("--skip-install", "-s", action="store_true",
|
||||
help="Skip the installation of prerequisites.")
|
||||
parser.add_argument("--output_dir", default="./demo/generated_tts",
|
||||
help="Output directory where the generated audio "
|
||||
"files will be saved. Default: "
|
||||
"'./demo/generated_tts'")
|
||||
parser.add_argument("--cut-off-sec", type=float, default=3.0,
|
||||
help="Cut-off time in seconds for the audio prompt ("
|
||||
"hundredths of a second are acceptable). "
|
||||
"Default: 3.0")
|
||||
parser.add_argument("--left_margin", type=float, default=0.08,
|
||||
help="Left margin of the audio segment used for "
|
||||
"speech editing. This is not used for "
|
||||
"text-to-speech synthesis. Default: 0.08")
|
||||
parser.add_argument("--right_margin", type=float, default=0.08,
|
||||
help="Right margin of the audio segment used for "
|
||||
"speech editing. This is not used for "
|
||||
"text-to-speech synthesis. Default: 0.08")
|
||||
parser.add_argument("--codec_audio_sr", type=int, default=16000,
|
||||
help="Sample rate of the audio codec used for "
|
||||
"encoding and decoding. Default: 16000")
|
||||
parser.add_argument("--codec_sr", type=int, default=50,
|
||||
help="Sample rate of the codec used for encoding and "
|
||||
"decoding. Default: 50")
|
||||
parser.add_argument("--top_k", type=int, default=0,
|
||||
help="Top-k sampling parameter. It limits the number "
|
||||
"of highest probability tokens to consider "
|
||||
"during generation. A higher value (e.g., "
|
||||
"50) will result in more diverse but potentially "
|
||||
"less coherent speech, while a lower value ("
|
||||
"e.g., 1) will result in more conservative and "
|
||||
"repetitive speech. Setting it to 0 disables "
|
||||
"top-k sampling. Default: 0")
|
||||
parser.add_argument("--top_p", type=float, default=0.8,
|
||||
help="Top-p sampling parameter. It controls the "
|
||||
"diversity of the generated audio by truncating "
|
||||
"the least likely tokens whose cumulative "
|
||||
"probability exceeds 'p'. Lower values (e.g., "
|
||||
"0.5) will result in more conservative and "
|
||||
"repetitive speech, while higher values (e.g., "
|
||||
"0.9) will result in more diverse speech. "
|
||||
"Default: 0.8")
|
||||
parser.add_argument("--temperature", type=float, default=1.0,
|
||||
help="Sampling temperature. It controls the "
|
||||
"randomness of the generated speech. Higher "
|
||||
"values (e.g., 1.5) will result in more "
|
||||
"expressive and varied speech, while lower "
|
||||
"values (e.g., 0.5) will result in more "
|
||||
"monotonous and conservative speech. Default: 1.0")
|
||||
parser.add_argument("--kvcache", type=int, default=1,
|
||||
help="Key-value cache size used for caching "
|
||||
"intermediate results. A larger cache size may "
|
||||
"improve performance but consume more memory. "
|
||||
"Default: 1")
|
||||
parser.add_argument("--seed", type=int, default=1,
|
||||
help="Random seed for reproducibility. Use the same "
|
||||
"seed value to generate the same output for a "
|
||||
"given input. Default: 1")
|
||||
parser.add_argument("--stop_repetition", type=int, default=3,
|
||||
help="Stop repetition threshold. It controls the "
|
||||
"number of consecutive repetitions allowed in "
|
||||
"the generated speech. Lower values (e.g., "
|
||||
"1 or 2) will result in less repetitive speech "
|
||||
"but may also lead to abrupt stopping. Higher "
|
||||
"values (e.g., 4 or 5) will allow more "
|
||||
"repetitions. Default: 3")
|
||||
parser.add_argument("--sample_batch_size", type=int, default=4,
|
||||
help="Number of audio samples generated in parallel. "
|
||||
"Increasing this value may improve the quality "
|
||||
"of the generated speech by reducing long "
|
||||
"silences or unnaturally stretched words, "
|
||||
"but it will also increase memory usage. "
|
||||
"Default: 4")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
if not args.skip_install:
|
||||
install_dependencies()
|
||||
create_conda_environment()
|
||||
check_python_dependencies()
|
||||
|
||||
orig_audio = args.audio
|
||||
orig_transcript = args.transcript
|
||||
output_dir = args.output_dir
|
||||
|
||||
# Create the output directory if it doesn't exist
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Hyperparameters for inference
|
||||
left_margin = args.left_margin
|
||||
right_margin = args.right_margin
|
||||
codec_audio_sr = args.codec_audio_sr
|
||||
codec_sr = args.codec_sr
|
||||
top_k = args.top_k
|
||||
top_p = args.top_p
|
||||
temperature = args.temperature
|
||||
kvcache = args.kvcache
|
||||
silence_tokens = [1388, 1898, 131]
|
||||
seed = args.seed
|
||||
stop_repetition = args.stop_repetition
|
||||
sample_batch_size = args.sample_batch_size
|
||||
|
||||
# Set the device based on available hardware
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif sys.platform == "darwin" and torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
# Move audio and transcript to temp folder
|
||||
temp_folder = "./demo/temp"
|
||||
os.makedirs(temp_folder, exist_ok=True)
|
||||
subprocess.run(["cp", orig_audio, temp_folder])
|
||||
filename = os.path.splitext(os.path.basename(orig_audio))[0]
|
||||
with open(f"{temp_folder}/{filename}.txt", "w") as f:
|
||||
f.write(orig_transcript)
|
||||
|
||||
# Run MFA to get the alignment
|
||||
align_temp = f"{temp_folder}/mfa_alignments"
|
||||
os.makedirs(align_temp, exist_ok=True)
|
||||
subprocess.run(
|
||||
["mfa", "model", "download", "dictionary", "english_us_arpa"])
|
||||
subprocess.run(["mfa", "model", "download", "acoustic", "english_us_arpa"])
|
||||
subprocess.run(
|
||||
["mfa", "align", "-v", "--clean", "-j", "1", "--output_format", "csv",
|
||||
temp_folder, "english_us_arpa", "english_us_arpa", align_temp])
|
||||
|
||||
audio_fn = f"{temp_folder}/{os.path.basename(orig_audio)}"
|
||||
transcript_fn = f"{temp_folder}/{filename}.txt"
|
||||
align_fn = f"{align_temp}/{filename}.csv"
|
||||
|
||||
# Decide which part of the audio to use as prompt based on forced alignment
|
||||
cut_off_sec = args.cut_off_sec
|
||||
|
||||
info = torchaudio.info(audio_fn)
|
||||
audio_dur = info.num_frames / info.sample_rate
|
||||
assert cut_off_sec < audio_dur, f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
|
||||
prompt_end_frame = int(cut_off_sec * info.sample_rate)
|
||||
|
||||
# Load model, tokenizer, and other necessary files
|
||||
voicecraft_name = "giga830M.pth"
|
||||
ckpt_fn = f"./pretrained_models/{voicecraft_name}"
|
||||
encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
|
||||
|
||||
if not os.path.exists(ckpt_fn):
|
||||
subprocess.run(["wget",
|
||||
f"https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}?download=true"])
|
||||
subprocess.run(["mv", f"{voicecraft_name}?download=true",
|
||||
f"./pretrained_models/{voicecraft_name}"])
|
||||
|
||||
if not os.path.exists(encodec_fn):
|
||||
subprocess.run(["wget",
|
||||
"https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th"])
|
||||
subprocess.run(["mv", "encodec_4cb2048_giga.th",
|
||||
"./pretrained_models/encodec_4cb2048_giga.th"])
|
||||
|
||||
ckpt = torch.load(ckpt_fn, map_location="cpu")
|
||||
model = voicecraft.VoiceCraft(ckpt["config"])
|
||||
model.load_state_dict(ckpt["model"])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
phn2num = ckpt['phn2num']
|
||||
text_tokenizer = TextTokenizer(backend="espeak")
|
||||
audio_tokenizer = AudioTokenizer(
|
||||
signature=encodec_fn) # will also put the neural codec model on gpu
|
||||
|
||||
# Run the model to get the output
|
||||
decode_config = {
|
||||
'top_k': top_k, 'top_p': top_p, 'temperature': temperature,
|
||||
'stop_repetition': stop_repetition, 'kvcache': kvcache,
|
||||
"codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
|
||||
"silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size
|
||||
}
|
||||
concated_audio, gen_audio = inference_one_sample(
|
||||
model, ckpt["config"], phn2num, text_tokenizer, audio_tokenizer,
|
||||
audio_fn, transcript_fn, device, decode_config, prompt_end_frame
|
||||
)
|
||||
|
||||
# Save segments for comparison
|
||||
concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
|
||||
|
||||
# Save the audio
|
||||
seg_save_fn_gen = os.path.join(output_dir,
|
||||
f"{os.path.basename(orig_audio)[:-4]}_gen_seed{seed}.wav")
|
||||
seg_save_fn_concat = os.path.join(output_dir,
|
||||
f"{os.path.basename(orig_audio)[:-4]}_concat_seed{seed}.wav")
|
||||
|
||||
torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr)
|
||||
torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue