diff --git a/inference_tts.py b/inference_tts.py new file mode 100644 index 0000000..48d2e90 --- /dev/null +++ b/inference_tts.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 + +import os +import subprocess +import sys +import argparse +import importlib + +description = """ +VoiceCraft Inference Text-to-Speech Demo +This script demonstrates how to use the VoiceCraft model for text-to-speech synthesis. + +Pre-Requirements: +- Python 3.9.16 +- Conda (https://docs.conda.io/en/latest/miniconda.html) +- FFmpeg +- eSpeak NG + +Usage: +1. Prepare an audio file and its corresponding transcript. +2. Run the script with the required command-line arguments: + python voicecraft_tts_demo.py --audio --transcript +3. The generated audio files will be saved in the `./demo/generated_tts` directory. + +Notes: +- The script will download the required models automatically if they are not found in the `./pretrained_models` directory. +- You can adjust the hyperparameters using command-line arguments to fine-tune the text-to-speech synthesis. +""" + +def is_tool(name): + """Check whether `name` is on PATH and marked as executable.""" + return shutil.which(name) is not None + +def run_command(command, error_message): + try: + subprocess.run(command, check=True) + except subprocess.CalledProcessError as e: + print(f"Error: {e}") + print(error_message) + sys.exit(1) + +def install_linux_dependencies(): + if is_tool("apt-get"): + # Debian, Ubuntu, and derivatives + run_command(["sudo", "apt-get", "update"], + "Failed to update package lists.") + run_command(["sudo", "apt-get", "install", "-y", "git-core", "ffmpeg", "espeak-ng"], + "Failed to install Linux dependencies.") + elif is_tool("pacman"): + # Arch Linux and derivatives + run_command(["sudo", "pacman", "-Syu", "--noconfirm", "git", "ffmpeg", "espeak-ng"], + "Failed to install Linux dependencies.") + elif is_tool("dnf"): + # Fedora and derivatives + run_command(["sudo", "dnf", "install", "-y", "git", "ffmpeg", "espeak-ng"], + "Failed to install Linux dependencies.") + elif is_tool("yum"): + # CentOS and derivatives + run_command(["sudo", "yum", "install", "-y", "git", "ffmpeg", "espeak-ng"], + "Failed to install Linux dependencies.") + else: + print("Error: Unsupported Linux distribution. Please install the dependencies manually.") + sys.exit(1) + +def install_macos_dependencies(): + if is_tool("brew"): + run_command(["brew", "install", "git", "ffmpeg", "espeak"], + "Failed to install macOS dependencies.") + else: + print("Error: Homebrew not found. Please install Homebrew and try again.") + sys.exit(1) + +def install_dependencies(): + if sys.platform == "win32": + print(description) + print("Please install the required dependencies manually on Windows.") + sys.exit(1) + elif sys.platform == "darwin": + install_macos_dependencies() + elif sys.platform.startswith("linux"): + install_linux_dependencies() + else: + print(f"Unsupported platform: {sys.platform}") + sys.exit(1) + +def create_conda_environment(): + run_command(["conda", "create", "-y", "-n", "voicecraft", "python=3.9.16"], + "Failed to create Conda environment.") + run_command(["conda", "init", "bash"], + "Failed to initialize Conda.") + run_command(["source", "~/.bashrc"], + "Failed to source .bashrc.") + run_command(["conda", "activate", "voicecraft"], + "Failed to activate Conda environment.") + +def install_python_dependencies(): + conda_packages = [ + "montreal-forced-aligner=2.2.17", + "openfst=1.8.2", + "kaldi=5.5.1068" + ] + pip_packages = [ + "torch==2.0.1", + "tensorboard==2.16.2", + "phonemizer==3.2.1", + "torchaudio==2.0.2", + "datasets==2.16.0", + "torchmetrics==0.11.1", + "-e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft" + ] + + run_command(["conda", "install", "-y", "-c", "conda-forge"] + conda_packages, + "Failed to install Conda packages.") + run_command(["pip", "install"] + pip_packages, + "Failed to install Python packages.") + +def download_models(ckpt_fn, encodec_fn): + if not os.path.exists(ckpt_fn): + run_command(["wget", f"https://huggingface.co/pyp1/VoiceCraft/resolve/main/{os.path.basename(ckpt_fn)}?download=true"], + f"Failed to download {ckpt_fn}.") + run_command(["mv", f"{os.path.basename(ckpt_fn)}?download=true", ckpt_fn], + f"Failed to move {ckpt_fn}.") + + if not os.path.exists(encodec_fn): + run_command(["wget", "https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th"], + f"Failed to download {encodec_fn}.") + run_command(["mv", "encodec_4cb2048_giga.th", encodec_fn], + f"Failed to move {encodec_fn}.") + +def check_python_dependencies(): + dependencies = [ + "torch", + "torchaudio", + "data.tokenizer", + "models.voicecraft", + "inference_tts_scale" + ] + + missing_dependencies = [] + for dependency in dependencies: + try: + importlib.import_module(dependency) + except ImportError: + missing_dependencies.append(dependency) + + if missing_dependencies: + print("Missing Python dependencies:", missing_dependencies) + install_python_dependencies() + +def parse_arguments(): + parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("-a", "--audio", required=True, help="Path to the input audio file used as a reference for the voice.") + parser.add_argument("-t", "--transcript", required=True, help="Path to the text file containing the transcript to be synthesized.") + parser.add_argument("--output_dir", default="./demo/generated_tts", help="Output directory where the generated audio files will be saved. Default: './demo/generated_tts'") + parser.add_argument("--left_margin", type=float, default=0.08, help="Left margin of the audio segment used for speech editing. This is not used for text-to-speech synthesis. Default: 0.08") + parser.add_argument("--right_margin", type=float, default=0.08, help="Right margin of the audio segment used for speech editing. This is not used for text-to-speech synthesis. Default: 0.08") + parser.add_argument("--codec_audio_sr", type=int, default=16000, help="Sample rate of the audio codec used for encoding and decoding. Default: 16000") + parser.add_argument("--codec_sr", type=int, default=50, help="Sample rate of the codec used for encoding and decoding. Default: 50") + parser.add_argument("--top_k", type=int, default=0, help="Top-k sampling parameter. It limits the number of highest probability tokens to consider during generation. A higher value (e.g., 50) will result in more diverse but potentially less coherent speech, while a lower value (e.g., 1) will result in more conservative and repetitive speech. Setting it to 0 disables top-k sampling. Default: 0") + parser.add_argument("--top_p", type=float, default=0.8, help="Top-p sampling parameter. It controls the diversity of the generated audio by truncating the least likely tokens whose cumulative probability exceeds 'p'. Lower values (e.g., 0.5) will result in more conservative and repetitive speech, while higher values (e.g., 0.9) will result in more diverse speech. Default: 0.8") + parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. It controls the randomness of the generated speech. Higher values (e.g., 1.5) will result in more expressive and varied speech, while lower values (e.g., 0.5) will result in more monotonous and conservative speech. Default: 1.0") + parser.add_argument("--kvcache", type=int, default=1, help="Key-value cache size used for caching intermediate results. A larger cache size may improve performance but consume more memory. Default: 1") + parser.add_argument("--seed", type=int, default=1, help="Random seed for reproducibility. Use the same seed value to generate the same output for a given input. Default: 1") + parser.add_argument("--stop_repetition", type=int, default=3, help="Stop repetition threshold. It controls the number of consecutive repetitions allowed in the generated speech. Lower values (e.g., 1 or 2) will result in less repetitive speech but may also lead to abrupt stopping. Higher values (e.g., 4 or 5) will allow more repetitions. Default: 3") + parser.add_argument("--sample_batch_size", type=int, default=4, help="Number of audio samples generated in parallel. Increasing this value may improve the quality of the generated speech by reducing long silences or unnaturally stretched words, but it will also increase memory usage. Default: 4") + return parser.parse_args() + +def main(): + args = parse_arguments() + + install_dependencies() + create_conda_environment() + check_python_dependencies() + + orig_audio = args.audio + orig_transcript = args.transcript + output_dir = args.output_dir + + # Create the output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Hyperparameters for inference + left_margin = args.left_margin + right_margin = args.right_margin + codec_audio_sr = args.codec_audio_sr + codec_sr = args.codec_sr + top_k = args.top_k + top_p = args.top_p + temperature = args.temperature + kvcache = args.kvcache + silence_tokens = [1388, 1898, 131] + seed = args.seed + stop_repetition = args.stop_repetition + sample_batch_size = args.sample_batch_size + + # Set the device based on available hardware + if torch.cuda.is_available(): + device = "cuda" + elif sys.platform == "darwin" and torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + # Move audio and transcript to temp folder + temp_folder = "./demo/temp" + os.makedirs(temp_folder, exist_ok=True) + subprocess.run(["cp", orig_audio, temp_folder]) + filename = os.path.splitext(os.path.basename(orig_audio))[0] + with open(f"{temp_folder}/{filename}.txt", "w") as f: + f.write(orig_transcript) + + # Run MFA to get the alignment + align_temp = f"{temp_folder}/mfa_alignments" + os.makedirs(align_temp, exist_ok=True) + subprocess.run(["mfa", "model", "download", "dictionary", "english_us_arpa"]) + subprocess.run(["mfa", "model", "download", "acoustic", "english_us_arpa"]) + subprocess.run(["mfa", "align", "-v", "--clean", "-j", "1", "--output_format", "csv", temp_folder, "english_us_arpa", "english_us_arpa", align_temp]) + + audio_fn = f"{temp_folder}/{filename}.wav" + transcript_fn = f"{temp_folder}/{filename}.txt" + align_fn = f"{align_temp}/{filename}.csv" + + # Decide which part of the audio to use as prompt based on forced alignment + cut_off_sec = 3.01 # NOTE: According to forced-alignment file, the word "common" stops at 3.01 sec, this should be different for different audio + target_transcript = "But when I had approached so near to them The common I cannot believe that the same model can also do text to speech synthesis as well!" + + info = torchaudio.info(audio_fn) + audio_dur = info.num_frames / info.sample_rate + assert cut_off_sec < audio_dur, f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}" + prompt_end_frame = int(cut_off_sec * info.sample_rate) + + # Load model, tokenizer, and other necessary files + voicecraft_name = "giga830M.pth" + ckpt_fn = f"./pretrained_models/{voicecraft_name}" + encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th" + + if not os.path.exists(ckpt_fn): + subprocess.run(["wget", f"https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}?download=true"]) + subprocess.run(["mv", f"{voicecraft_name}?download=true", f"./pretrained_models/{voicecraft_name}"]) + + if not os.path.exists(encodec_fn): + subprocess.run(["wget", "https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th"]) + subprocess.run(["mv", "encodec_4cb2048_giga.th", "./pretrained_models/encodec_4cb2048_giga.th"]) + + ckpt = torch.load(ckpt_fn, map_location="cpu") + model = voicecraft.VoiceCraft(ckpt["config"]) + model.load_state_dict(ckpt["model"]) + model.to(device) + model.eval() + + phn2num = ckpt['phn2num'] + text_tokenizer = TextTokenizer(backend="espeak") + audio_tokenizer = AudioTokenizer(signature=encodec_fn) # will also put the neural codec model on gpu + + # Run the model to get the output + decode_config = { + 'top_k': top_k, 'top_p': top_p, 'temperature': temperature, + 'stop_repetition': stop_repetition, 'kvcache': kvcache, + "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr, + "silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size + } + concated_audio, gen_audio = inference_one_sample( + model, ckpt["config"], phn2num, text_tokenizer, audio_tokenizer, + audio_fn, target_transcript, device, decode_config, prompt_end_frame + ) + + # Save segments for comparison + concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu() + + # Save the audio + seg_save_fn_gen = os.path.join(output_dir, f"{os.path.basename(orig_audio)[:-4]}_gen_seed{seed}.wav") + seg_save_fn_concat = os.path.join(output_dir, f"{os.path.basename(orig_audio)[:-4]}_concat_seed{seed}.wav") + + torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr) + torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr) + +if __name__ == "__main__": + main() \ No newline at end of file