# VoiceCraft/predict.py
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import time
import numpy as np
import random
import getpass
import torch
import torchaudio
import shutil
import subprocess
import sys
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

os.environ["USER"] = getpass.getuser()
from data.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
)
from cog import BasePredictor, Input, Path
from models import voicecraft
from inference_tts_scale import inference_one_sample
from edit_utils import get_span
from inference_speech_editing_scale import get_mask_interval
from inference_speech_editing_scale import (
    inference_one_sample as inference_one_sample_editing,
)
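
# conda environment used to run the Montreal Forced Aligner (MFA); it is
# activated in a subprocess inside predict() because MFA has its own
# dependency stack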
ENV_NAME = "myenv"

MODEL_URL = "https://weights.replicate.delivery/default/VoiceCraft.tar"
MODEL_CACHE = "model_cache"


def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    # pget is Replicate's parallel downloader; "-x" extracts the weights tarball
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.device = "cuda"
        voicecraft_name = "giga830M.pth"  # or giga330M.pth
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)
        encodec_fn = f"{MODEL_CACHE}/encodec_4cb2048_giga.th"
        ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}"
        self.ckpt = torch.load(ckpt_fn, map_location="cpu")
        self.model = voicecraft.VoiceCraft(self.ckpt["config"])
        self.model.load_state_dict(self.ckpt["model"])
        self.model.to(self.device)
        self.model.eval()
        self.phn2num = self.ckpt["phn2num"]
        self.text_tokenizer = TextTokenizer(backend="espeak")
        self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)

    def predict(
        self,
        task: str = Input(
            description="Choose a task. For zero-shot text-to-speech, you also need to specify cut_off_sec (how much of the original audio to use as the prompt) and orig_transcript_until_cutoff_time (its transcript up to that point)",
            choices=[
                "speech_editing-substitution",
                "speech_editing-insertion",
                "speech_editing-deletion",
                "zero-shot text-to-speech",
            ],
            default="speech_editing-substitution",
        ),
        orig_audio: Path = Input(description="Original audio file"),
        orig_transcript: str = Input(
            description="Transcript of the original audio file. You can use models such as https://replicate.com/openai/whisper and https://replicate.com/vaibhavs10/incredibly-fast-whisper to get the transcript (and correct it if it's not accurate)",
        ),
        target_transcript: str = Input(
            description="Transcript of the target audio file",
        ),
        cut_off_sec: float = Input(
            description="Only used for the zero-shot text-to-speech task. The first cut_off_sec seconds of the original audio are used as the voice-cloning prompt. Three seconds of reference audio are generally enough for high-quality cloning, but longer is usually better; try e.g. 3~6 sec",
            default=3.01,
        ),
        orig_transcript_until_cutoff_time: str = Input(
            description="Required for the zero-shot text-to-speech task. Transcript of the original audio up to the cut_off_sec specified above. This step will be automated in a future version",
            default=None,
        ),
        temperature: float = Input(
            description="Adjusts randomness of outputs; values greater than 1 are more random, values close to 0 are near-deterministic",
            ge=0.01,
            le=5,
            default=1,
        ),
        top_p: float = Input(
            description="When decoding, sample only from the most likely tokens whose cumulative probability reaches p; lower this to ignore less likely tokens",
            ge=0.0,
            le=1.0,
            default=0.8,
        ),
        stop_repetition: int = Input(
            default=-1,
            description="-1 means do not adjust the probability of silence tokens. If the generated audio has long silences or unnaturally stretched words, try setting this to 3, 2, or even 1",
        ),
        sampling_rate: int = Input(
            description="Specify the sampling rate of the audio codec", default=16000
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model"""
        if task == "zero-shot text-to-speech":
            assert (
                orig_transcript_until_cutoff_time is not None
            ), "Please provide orig_transcript_until_cutoff_time for zero-shot text-to-speech task."
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(seed)
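        # os.urandom(2) above yields two random bytes, so an auto-generated seed
        # is an integer in the range [0, 65535]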

        temp_folder = "exp_dir"
        if os.path.exists(temp_folder):
            shutil.rmtree(temp_folder)
        os.makedirs(temp_folder)

        filename = "orig_audio"
        shutil.copy(orig_audio, f"{temp_folder}/{filename}.wav")
        with open(f"{temp_folder}/{filename}.txt", "w") as f:
            f.write(orig_transcript)

        # run MFA to get the alignment
        align_temp = f"{temp_folder}/mfa_alignments"
        command = f'/bin/bash -c "source /cog/miniconda/bin/activate && conda activate {ENV_NAME} && mfa align -v --clean -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}"'
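        # mfa align takes <corpus_dir> <dictionary> <acoustic_model> <output_dir>;
        # here english_us_arpa serves as both dictionary and acoustic model, and
        # word/phone alignments are written as CSV into align_temp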
        try:
            subprocess.run(command, shell=True, check=True)
        except subprocess.CalledProcessError as e:
            print("Error:", e)
            raise RuntimeError("Error running Alignment")
        print("Alignment done!")

        align_fn = f"{align_temp}/{filename}.csv"
        audio_fn = f"{temp_folder}/{filename}.wav"
        info = torchaudio.info(audio_fn)
        audio_dur = info.num_frames / info.sample_rate

        # hyperparameters for inference
        left_margin = 0.08
        right_margin = 0.08
        codec_sr = 50  # codec frame rate, in frames per second
        top_k = 0  # 0 means top-k filtering is not used
        silence_tokens = [1388, 1898, 131]  # codec token ids treated as silence
        kvcache = 1 if task == "zero-shot text-to-speech" else 0
        # NOTE: if there are long silences or unnaturally stretched words, increase
        # sample_batch_size to 5 or higher. The model will then generate
        # sample_batch_size variants of the same audio and pick the shortest one,
        # so if the generated speech rate is too fast, change it to a smaller number.
        sample_batch_size = 4

if task == "":
assert (
cut_off_sec < audio_dur
), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
prompt_end_frame = int(cut_off_sec * info.sample_rate)
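            # e.g. the default cut_off_sec of 3.01 with a 16 kHz source gives
            # prompt_end_frame = int(3.01 * 16000) = 48160 samples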
        else:
            edit_type = task.split("-")[-1]
            orig_span, new_span = get_span(
                orig_transcript, target_transcript, edit_type
            )
            if orig_span[0] > orig_span[1]:
                raise RuntimeError(f"example {audio_fn} failed")
            if orig_span[0] == orig_span[1]:
                orig_span_save = [orig_span[0]]
            else:
                orig_span_save = orig_span
            if new_span[0] == new_span[1]:
                new_span_save = [new_span[0]]
            else:
                new_span_save = new_span
            orig_span_save = ",".join([str(item) for item in orig_span_save])
            new_span_save = ",".join([str(item) for item in new_span_save])
            start, end = get_mask_interval(align_fn, orig_span_save, edit_type)
            # span to be edited, with safety margins, in seconds
            morphed_span = (
                max(start - left_margin, 1 / codec_sr),
                min(end + right_margin, audio_dur),
            )
            # convert the span to codec frames
            mask_interval = [
                [round(morphed_span[0] * codec_sr), round(morphed_span[1] * codec_sr)]
            ]
            mask_interval = torch.LongTensor(mask_interval)  # [M, 2], M == 1 for now
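            # e.g. with codec_sr = 50, an edit span of 1.00-2.50 s becomes
            # mask_interval = [[50, 125]] in codec frame indices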

        decode_config = {
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
            "stop_repetition": stop_repetition,
            "kvcache": kvcache,
            "codec_audio_sr": sampling_rate,
            "codec_sr": codec_sr,
            "silence_tokens": silence_tokens,
        }
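
        # the same decode_config drives both inference branches below; only the
        # TTS branch additionally sets sample_batch_size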
        if task == "zero-shot text-to-speech":
            decode_config["sample_batch_size"] = sample_batch_size
            concated_audio, gen_audio = inference_one_sample(
                self.model,
                self.ckpt["config"],
                self.phn2num,
                self.text_tokenizer,
                self.audio_tokenizer,
                audio_fn,
                orig_transcript_until_cutoff_time.strip()
                + " "
                + target_transcript.strip(),
                self.device,
                decode_config,
                prompt_end_frame,
            )
        else:
            orig_audio, gen_audio = inference_one_sample_editing(
                self.model,
                self.ckpt["config"],
                self.phn2num,
                self.text_tokenizer,
                self.audio_tokenizer,
                audio_fn,
                target_transcript,
                mask_interval,
                self.device,
                decode_config,
            )

        # save the generated audio for inspection
        gen_audio = gen_audio[0].cpu()
        out = "/tmp/out.wav"
        torchaudio.save(out, gen_audio, sampling_rate)
        return Path(out)


def seed_everything(seed):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
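

# A minimal local-run sketch, for illustration only (an assumption: Cog normally
# instantiates Predictor itself, and "example.wav" and the transcripts below are
# hypothetical placeholders):
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    out_path = predictor.predict(
        task="zero-shot text-to-speech",
        orig_audio=Path("example.wav"),  # hypothetical input file
        orig_transcript="full transcript of example.wav",
        target_transcript="the new text to synthesize in the same voice",
        cut_off_sec=3.01,
        orig_transcript_until_cutoff_time="transcript of the first 3.01 seconds",
        temperature=1,
        top_p=0.8,
        stop_repetition=-1,
        sampling_rate=16000,
        seed=None,
    )
    print(out_path)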