# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import time
import numpy as np
import random
import getpass
import torch
import torchaudio
import shutil
import subprocess
import sys
import warnings

# Silence UserWarning messages for the whole process.
warnings.filterwarnings("ignore", category=UserWarning)

# Make sure $USER is set to the current login name.
# NOTE(review): presumably one of the imports that follow reads $USER at
# import time, so keep this assignment ahead of them -- confirm.
os.environ["USER"] = getpass.getuser()
from data . tokenizer import (
AudioTokenizer ,
TextTokenizer ,
)
from cog import BasePredictor , Input , Path
from models import voicecraft
from inference_tts_scale import inference_one_sample
# Helpers used by the speech-editing tasks.
from edit_utils import get_span
from inference_speech_editing_scale import get_mask_interval
from inference_speech_editing_scale import (
inference_one_sample as inference_one_sample_editing ,
)
# Conda environment that is activated before running `mfa align`.
ENV_NAME = "myenv"

# Remote weights archive and the local directory it is downloaded into.
MODEL_URL = "https://weights.replicate.delivery/default/VoiceCraft.tar"
MODEL_CACHE = "model_cache"
def download_weights(url, dest):
    """Fetch the weights at ``url`` into ``dest`` with the ``pget`` CLI.

    Progress and the elapsed time are printed to stdout.

    Args:
        url: Location of the weights archive.
        dest: Local path handed to ``pget`` as the destination.

    Raises:
        subprocess.CalledProcessError: If ``pget`` exits with a non-zero status.
        FileNotFoundError: If the ``pget`` executable is not on ``PATH``.
    """
    # A monotonic clock keeps the reported duration correct even if the
    # system wall clock is adjusted while the download is running.
    start = time.monotonic()
    print("downloading url:", url)
    print("downloading to:", dest)
    # NOTE(review): `-x` presumably tells pget to extract the tar archive
    # into `dest` -- confirm against the pget documentation.
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took:", time.monotonic() - start)
class Predictor(BasePredictor):
    """Cog predictor exposing VoiceCraft speech editing and zero-shot TTS."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient.

        Downloads the weights into ``MODEL_CACHE`` when that path does not
        exist yet, then builds the VoiceCraft model from the checkpoint and
        creates the text and audio tokenizers. Everything needed by
        ``predict`` is stored on ``self``.
        """
        self.device = "cuda"
        voicecraft_name = "giga830M.pth"  # or giga330M.pth

        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        encodec_fn = f"{MODEL_CACHE}/encodec_4cb2048_giga.th"
        ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}"

        # The checkpoint is deserialized on CPU; only the model itself is
        # moved to the target device afterwards.
        self.ckpt = torch.load(ckpt_fn, map_location="cpu")
        self.model = voicecraft.VoiceCraft(self.ckpt["config"])
        self.model.load_state_dict(self.ckpt["model"])
        self.model.to(self.device)
        self.model.eval()

        self.phn2num = self.ckpt["phn2num"]
        self.text_tokenizer = TextTokenizer(backend="espeak")
        self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)

    def predict(
        self,
        task: str = Input(
            description="Choose a task. For zero-shot text-to-speech, you also need to specify the cut_off_sec of the original audio to be used for zero-shot generation and the transcript until the cut_off_sec",
            choices=[
                "speech_editing-substitution",
                "speech_editing-insertion",
                # NOTE(review): "sdeletion" looks like a typo for "deletion".
                # The text after the last "-" is passed on as edit_type, so
                # verify what get_span / get_mask_interval accept. Left
                # unchanged because it is part of the public input schema.
                "speech_editing-sdeletion",
                "zero-shot text-to-speech",
            ],
            default="speech_editing-substitution",
        ),
        orig_audio: Path = Input(description="Original audio file"),
        orig_transcript: str = Input(
            description="Transcript of the original audio file. You can use models such as https://replicate.com/openai/whisper and https://replicate.com/vaibhavs10/incredibly-fast-whisper to get the transcript (and modify it if it's not accurate)",
        ),
        target_transcript: str = Input(
            description="Transcript of the target audio file",
        ),
        cut_off_sec: float = Input(
            description="Valid/Required for zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech (TTS). 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec",
            default=3.01,
        ),
        orig_transcript_until_cutoff_time: str = Input(
            description="Valid/Required for zero-shot text-to-speech task. Transcript of the original audio file until the cut_off_sec specified above. This process will be improved and made automatically later",
            default=None,
        ),
        temperature: float = Input(
            description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic",
            ge=0.01,
            le=5,
            default=1,
        ),
        top_p: float = Input(
            description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens",
            ge=0.0,
            le=1.0,
            default=0.8,
        ),
        stop_repetition: int = Input(
            default=-1,
            description="-1 means do not adjust prob of silence tokens. if there are long silence or unnaturally strecthed words, increase sample_batch_size to 2, 3 or even 4",
        ),
        sampling_rate: int = Input(
            description="Specify the sampling rate of the audio codec", default=16000
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model.

        The original audio and transcript are copied into a scratch folder
        and aligned with ``mfa align``. Depending on ``task`` the method then
        either continues the first ``cut_off_sec`` seconds of the audio with
        ``target_transcript`` (zero-shot TTS) or regenerates the part of the
        audio that differs between ``orig_transcript`` and
        ``target_transcript`` (speech editing).

        Returns:
            Path to ``/tmp/out.wav`` holding the generated audio.

        Raises:
            AssertionError: For the zero-shot task, when
                ``orig_transcript_until_cutoff_time`` is missing or
                ``cut_off_sec`` is not shorter than the audio.
            RuntimeError: When the alignment command fails, or when the
                computed edit span is inverted.
        """
        is_tts = task == "zero-shot text-to-speech"
        if is_tts:
            assert (
                orig_transcript_until_cutoff_time is not None
            ), "Please provide orig_transcript_until_cutoff_time for zero-shot text-to-speech task."

        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(seed)

        # Start every prediction from an empty scratch folder.
        temp_folder = "exp_dir"
        if os.path.exists(temp_folder):
            shutil.rmtree(temp_folder)
        os.makedirs(temp_folder)

        # The audio and its transcript share a base name inside the folder
        # that is handed to the aligner.
        filename = "orig_audio"
        audio_fn = f"{temp_folder}/{filename}.wav"
        shutil.copy(orig_audio, audio_fn)
        with open(f"{temp_folder}/{filename}.txt", "w") as f:
            f.write(orig_transcript)

        # run MFA to get the alignment
        align_temp = f"{temp_folder}/mfa_alignments"
        # The command is built only from module constants and fixed local
        # paths (no user input), so running it through the shell is safe.
        command = f'/bin/bash -c "source /cog/miniconda/bin/activate && conda activate {ENV_NAME} && mfa align -v --clean -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}"'
        try:
            subprocess.run(command, shell=True, check=True)
        except subprocess.CalledProcessError as e:
            print("Error:", e)
            # Chain the original failure so the exit status stays visible.
            raise RuntimeError("Error running Alignment") from e
        print("Alignment done!")

        align_fn = f"{align_temp}/{filename}.csv"
        info = torchaudio.info(audio_fn)
        audio_dur = info.num_frames / info.sample_rate

        # hyperparameters for inference
        left_margin = 0.08
        right_margin = 0.08
        codec_sr = 50
        top_k = 0
        silence_tokens = [1388, 1898, 131]
        kvcache = 1 if is_tts else 0
        # NOTE: if there are long silences or unnaturally stretched words,
        # increase sample_batch_size to 5 or higher. The model then runs
        # sample_batch_size examples of the same audio and picks the shortest
        # one, so if the generated speech is too fast, use a smaller number.
        sample_batch_size = 4

        decode_config = {
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
            "stop_repetition": stop_repetition,
            "kvcache": kvcache,
            "codec_audio_sr": sampling_rate,
            "codec_sr": codec_sr,
            "silence_tokens": silence_tokens,
        }

        if is_tts:
            assert (
                cut_off_sec < audio_dur
            ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
            # Length of the voice prompt in samples of the input audio.
            prompt_end_frame = int(cut_off_sec * info.sample_rate)

            decode_config["sample_batch_size"] = sample_batch_size
            # The prompt transcript and the new text are joined into one
            # target transcript, separated by a single space.
            _, gen_audio = inference_one_sample(
                self.model,
                self.ckpt["config"],
                self.phn2num,
                self.text_tokenizer,
                self.audio_tokenizer,
                audio_fn,
                orig_transcript_until_cutoff_time.strip()
                + " "
                + target_transcript.strip(),
                self.device,
                decode_config,
                prompt_end_frame,
            )
        else:
            edit_type = task.split("-")[-1]
            # Only the span in the original transcript is needed here.
            orig_span, _ = get_span(orig_transcript, target_transcript, edit_type)
            if orig_span[0] > orig_span[1]:
                raise RuntimeError(f"example {audio_fn} failed")

            # A span whose two ends coincide is passed on as a single index.
            if orig_span[0] == orig_span[1]:
                orig_span_save = [orig_span[0]]
            else:
                orig_span_save = orig_span
            orig_span_save = ",".join(str(item) for item in orig_span_save)

            start, end = get_mask_interval(align_fn, orig_span_save, edit_type)

            # Widen the interval by the margins, clamped to the audio.
            morphed_span = (
                max(start - left_margin, 1 / codec_sr),
                min(end + right_margin, audio_dur),
            )  # in seconds

            # span in codec frames
            mask_interval = [
                [round(morphed_span[0] * codec_sr), round(morphed_span[1] * codec_sr)]
            ]
            mask_interval = torch.LongTensor(mask_interval)  # [M,2], M==1 for now

            _, gen_audio = inference_one_sample_editing(
                self.model,
                self.ckpt["config"],
                self.phn2num,
                self.text_tokenizer,
                self.audio_tokenizer,
                audio_fn,
                target_transcript,
                mask_interval,
                self.device,
                decode_config,
            )

        # Write the first generated item to disk, moved to CPU first.
        gen_audio = gen_audio[0].cpu()
        out = "/tmp/out.wav"
        torchaudio.save(out, gen_audio, sampling_rate)
        return Path(out)
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed. Values outside NumPy's accepted range
            ``[0, 2**32 - 1]`` are folded into it for the NumPy generator
            only; the other generators receive ``seed`` unchanged.
    """
    # NOTE(review): hash randomization is configured at interpreter start-up,
    # so this presumably only affects child processes -- confirm if the
    # current process is expected to be influenced.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    # np.random.seed raises ValueError for seeds outside [0, 2**32 - 1];
    # fold the value into range so any user-supplied integer is usable.
    np.random.seed(seed % 2**32)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Disable cuDNN autotuning and request deterministic cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True