From 023d4b1c6c04786f39a2977c635338be33be7924 Mon Sep 17 00:00:00 2001
From: chenxwh
Date: Fri, 5 Apr 2024 17:23:39 +0000
Subject: [PATCH] replicate demo

---
 cog.yaml   |   1 -
 predict.py | 156 ++++++++++++++++++++++++++++++++++++++++-------------
 2 files changed, 119 insertions(+), 38 deletions(-)

diff --git a/cog.yaml b/cog.yaml
index 12328c6..a771ad5 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -8,7 +8,6 @@ build:
     - "libglib2.0-0"
     - ffmpeg
     - espeak-ng
-    # - cmake
   python_version: "3.9.16"
   python_packages:
     - torch==2.0.1
diff --git a/predict.py b/predict.py
index d7a4260..d8a2890 100644
--- a/predict.py
+++ b/predict.py
@@ -11,7 +11,9 @@ import torchaudio
 import shutil
 import subprocess
 import sys
+import warnings
 
+warnings.filterwarnings("ignore", category=UserWarning)
 os.environ["USER"] = getpass.getuser()
 
 from data.tokenizer import (
@@ -21,9 +23,14 @@ from data.tokenizer import (
 from cog import BasePredictor, Input, Path
 from models import voicecraft
 from inference_tts_scale import inference_one_sample
+from edit_utils import get_span
+from inference_speech_editing_scale import get_mask_interval
+from inference_speech_editing_scale import (
+    inference_one_sample as inference_one_sample_editing,
+)
 
 ENV_NAME = "myenv"
-# sys.path.append(f"/cog/miniconda/envs/{ENV_NAME}/lib/python3.10/site-packages")
+
 MODEL_URL = "https://weights.replicate.delivery/default/VoiceCraft.tar"
 MODEL_CACHE = "model_cache"
@@ -63,22 +70,33 @@ class Predictor(BasePredictor):
 
     def predict(
         self,
+        task: str = Input(
+            description="Choose a task. For zero-shot text-to-speech, you also need to specify the cut_off_sec of the original audio to be used for zero-shot generation, and the transcript until the cut_off_sec",
+            choices=[
+                "speech_editing-substitution",
+                "speech_editing-insertion",
+                "speech_editing-deletion",
+                "zero-shot text-to-speech",
+            ],
+            default="speech_editing-substitution",
+        ),
         orig_audio: Path = Input(description="Original audio file"),
         orig_transcript: str = Input(
             description="Transcript of the original audio file. You can use models such as https://replicate.com/openai/whisper and https://replicate.com/vaibhavs10/incredibly-fast-whisper to get the transcript (and modify it if it's not accurate)",
         ),
-        cut_off_sec: float = Input(
-            description="The first seconds of the original audio that are used for zero-shot text-to-speech (TTS). 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec",
-            default=3.01,
-        ),
-        orig_transcript_until_cutoff_time: str = Input(
-            description="Transcript of the original audio file until the cut_off_sec specified above. This process will be improved and made automatically later",
-        ),
         target_transcript: str = Input(
             description="Transcript of the target audio file",
         ),
+        cut_off_sec: float = Input(
+            description="Only used (and required) for the zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech (TTS). 3 sec of reference is generally enough for high-quality voice cloning, but longer is generally better; try e.g. 3~6 sec",
+            default=3.01,
+        ),
+        orig_transcript_until_cutoff_time: str = Input(
+            description="Only used (and required) for the zero-shot text-to-speech task. Transcript of the original audio file up to the cut_off_sec specified above. This process will be improved and automated later",
+            default=None,
+        ),
         temperature: float = Input(
-            description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic,
+            description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic",
             ge=0.01,
             le=5,
             default=1,
@@ -89,6 +107,10 @@ class Predictor(BasePredictor):
             le=1.0,
             default=0.8,
         ),
+        stop_repetition: int = Input(
+            default=-1,
+            description="-1 means do not adjust prob of silence tokens. If the generated audio has long silences or unnaturally stretched words, set stop_repetition to 3, 2 or even 1",
+        ),
         sampling_rate: int = Input(
             description="Specify the sampling rate of the audio codec", default=16000
         ),
@@ -97,20 +119,27 @@ class Predictor(BasePredictor):
         ),
     ) -> Path:
         """Run a single prediction on the model"""
+
+        if task == "zero-shot text-to-speech":
+            assert (
+                orig_transcript_until_cutoff_time is not None
+            ), "Please provide orig_transcript_until_cutoff_time for the zero-shot text-to-speech task."
         if seed is None:
             seed = int.from_bytes(os.urandom(2), "big")
 
         print(f"Using seed: {seed}")
         seed_everything(seed)
 
-        temp_folder = "exp_temp"
+        temp_folder = "exp_dir"
         if os.path.exists(temp_folder):
             shutil.rmtree(temp_folder)
 
         os.makedirs(temp_folder)
-        os.system(f"cp {str(orig_audio)} {temp_folder}")
-        # filename = os.path.splitext(orig_audio.split("/")[-1])[0]
-        with open(f"{temp_folder}/orig_audio_file.txt", "w") as f:
+
+        filename = "orig_audio"
+        shutil.copy(orig_audio, f"{temp_folder}/{filename}.wav")
+
+        with open(f"{temp_folder}/{filename}.txt", "w") as f:
             f.write(orig_transcript)
 
         # run MFA to get the alignment
@@ -121,26 +150,61 @@ class Predictor(BasePredictor):
             subprocess.run(command, shell=True, check=True)
         except subprocess.CalledProcessError as e:
             print("Error:", e)
+            raise RuntimeError("Error running alignment")
+
         print("Alignment done!")
 
-        audio_fn = str(orig_audio)  # f"{temp_folder}/{filename}.wav"
+        align_fn = f"{align_temp}/{filename}.csv"
+        audio_fn = f"{temp_folder}/{filename}.wav"
         info = torchaudio.info(audio_fn)
         audio_dur = info.num_frames / info.sample_rate
 
-        assert (
-            cut_off_sec < audio_dur
-        ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
-        prompt_end_frame = int(cut_off_sec * info.sample_rate)
-
+        # hyperparameters for inference
+        left_margin = 0.08
+        right_margin = 0.08
         codec_sr = 50
         top_k = 0
         silence_tokens = [1388, 1898, 131]
-        kvcache = 1  # NOTE if OOM, change this to 0, or try the 330M model
+        kvcache = 1 if task == "zero-shot text-to-speech" else 0
 
-        # NOTE adjust the below three arguments if the generation is not as good
-        stop_repetition = 3  # NOTE if the model generate long silence, reduce the stop_repetition to 3, 2 or even 1
         sample_batch_size = 4  # NOTE: if there are long silences or unnaturally stretched words, increase sample_batch_size to 5 or higher. The model will then run sample_batch_size examples of the same audio and pick the shortest one, so if the speech rate of the generated audio is too fast, change it to a smaller number.
+        if task == "zero-shot text-to-speech":
+            assert (
+                cut_off_sec < audio_dur
+            ), f"cut_off_sec {cut_off_sec} must be smaller than the audio duration {audio_dur}"
+            prompt_end_frame = int(cut_off_sec * info.sample_rate)
+
+        else:
+            edit_type = task.split("-")[-1]
+            orig_span, new_span = get_span(
+                orig_transcript, target_transcript, edit_type
+            )
+            if orig_span[0] > orig_span[1]:
+                raise RuntimeError(f"example {audio_fn} failed")
+            if orig_span[0] == orig_span[1]:
+                orig_span_save = [orig_span[0]]
+            else:
+                orig_span_save = orig_span
+            if new_span[0] == new_span[1]:
+                new_span_save = [new_span[0]]
+            else:
+                new_span_save = new_span
+            orig_span_save = ",".join([str(item) for item in orig_span_save])
+            new_span_save = ",".join([str(item) for item in new_span_save])
+            start, end = get_mask_interval(align_fn, orig_span_save, edit_type)
+
+            # span in codec frames
+            morphed_span = (
+                max(start - left_margin, 1 / codec_sr),
+                min(end + right_margin, audio_dur),
+            )  # in seconds
+            mask_interval = [
+                [round(morphed_span[0] * codec_sr), round(morphed_span[1] * codec_sr)]
+            ]
+            mask_interval = torch.LongTensor(mask_interval)  # [M,2], M==1 for now
+
         decode_config = {
             "top_k": top_k,
             "top_p": top_p,
@@ -150,27 +214,45 @@ class Predictor(BasePredictor):
             "codec_audio_sr": sampling_rate,
             "codec_sr": codec_sr,
             "silence_tokens": silence_tokens,
-            "sample_batch_size": sample_batch_size,
         }
-        concated_audio, gen_audio = inference_one_sample(
-            self.model,
-            self.ckpt["config"],
-            self.phn2num,
-            self.text_tokenizer,
-            self.audio_tokenizer,
-            audio_fn,
-            orig_transcript_until_cutoff_time.strip() + "" + target_transcript.strip(),
-            self.device,
-            decode_config,
-            prompt_end_frame,
-        )
+
+        if task == "zero-shot text-to-speech":
+            decode_config["sample_batch_size"] = sample_batch_size
+
+            concated_audio, gen_audio = inference_one_sample(
+                self.model,
+                self.ckpt["config"],
+                self.phn2num,
+                self.text_tokenizer,
+                self.audio_tokenizer,
+                audio_fn,
+                orig_transcript_until_cutoff_time.strip()
+                + " "
+                + target_transcript.strip(),
+                self.device,
+                decode_config,
+                prompt_end_frame,
+            )
+
+        else:
+            orig_audio, gen_audio = inference_one_sample_editing(
+                self.model,
+                self.ckpt["config"],
+                self.phn2num,
+                self.text_tokenizer,
+                self.audio_tokenizer,
+                audio_fn,
+                target_transcript,
+                mask_interval,
+                self.device,
+                decode_config,
+            )
 
         # save segments for comparison
-        concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
+        gen_audio = gen_audio[0].cpu()
 
         out = "/tmp/out.wav"
         torchaudio.save(out, gen_audio, sampling_rate)
-        torchaudio.save("out.wav", gen_audio, sampling_rate)
 
         return Path(out)
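
A note for reviewers: below is a minimal local smoke test for the new predictor, meant to run inside the image built from cog.yaml after setup() has fetched the weights. The demo WAV path, both transcripts, and the seed are illustrative placeholders, not part of this patch; any short clip with an accurate transcript works.

    # Hypothetical smoke test for the speech-editing branch; paths and
    # transcripts are placeholders, the other values are the Input defaults.
    from cog import Path as CogPath
    from predict import Predictor

    predictor = Predictor()
    predictor.setup()  # downloads/loads the VoiceCraft checkpoint and tokenizers

    out = predictor.predict(
        task="speech_editing-substitution",
        orig_audio=CogPath("demo/sample.wav"),  # placeholder clip
        orig_transcript="but when i had approached so near to them",
        target_transcript="but when i had walked so near to them",
        cut_off_sec=3.01,                        # ignored for editing tasks
        orig_transcript_until_cutoff_time=None,  # only needed for zero-shot TTS
        temperature=1,
        top_p=0.8,
        stop_repetition=-1,
        sampling_rate=16000,
        seed=42,
    )
    print(out)  # -> /tmp/out.wav

The same call with task="zero-shot text-to-speech" exercises the other branch, provided cut_off_sec and orig_transcript_until_cutoff_time are set to match the prompt audio.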