Mirror of https://github.com/jasonppy/VoiceCraft.git
(synced 2025-06-05 21:49:11 +02:00)

Commit: replicate demo
cog.yaml (1 changed line)

@@ -8,7 +8,6 @@ build:
     - "libglib2.0-0"
     - ffmpeg
     - espeak-ng
-    # - cmake
   python_version: "3.9.16"
   python_packages:
     - torch==2.0.1
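Note on the build config: ffmpeg and espeak-ng are runtime binaries (audio decoding and phonemization, respectively), not Python packages, which is why they sit under system_packages. A minimal sketch, using only the standard library, of a setup-time guard such a build could use to fail fast if they are missing; check_system_deps is a hypothetical helper, not part of this commit:

import shutil

# System binaries cog.yaml is expected to install (assumption: ffmpeg for
# audio decoding, espeak-ng for phonemization).
REQUIRED_BINARIES = ["ffmpeg", "espeak-ng"]

def check_system_deps() -> None:
    """Fail fast if an apt dependency from cog.yaml is not on PATH."""
    missing = [b for b in REQUIRED_BINARIES if shutil.which(b) is None]
    if missing:
        raise RuntimeError(f"missing system packages: {', '.join(missing)}")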
predict.py (156 changed lines)
@@ -11,7 +11,9 @@ import torchaudio
 import shutil
 import subprocess
 import sys
+import warnings
 
+warnings.filterwarnings("ignore", category=UserWarning)
 os.environ["USER"] = getpass.getuser()
 
 from data.tokenizer import (
@@ -21,9 +23,14 @@ from data.tokenizer import (
 from cog import BasePredictor, Input, Path
 from models import voicecraft
 from inference_tts_scale import inference_one_sample
+from edit_utils import get_span
+from inference_speech_editing_scale import get_mask_interval
+from inference_speech_editing_scale import (
+    inference_one_sample as inference_one_sample_editing,
+)
 
 ENV_NAME = "myenv"
-# sys.path.append(f"/cog/miniconda/envs/{ENV_NAME}/lib/python3.10/site-packages")
 
 MODEL_URL = "https://weights.replicate.delivery/default/VoiceCraft.tar"
 MODEL_CACHE = "model_cache"
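The setup code that consumes MODEL_URL and MODEL_CACHE is outside these hunks. A hedged sketch of what the weight-fetching step plausibly looks like; download_weights is a hypothetical helper, and the real predictor may use a faster downloader:

import os
import tarfile
import urllib.request

def download_weights(url: str, dest: str) -> None:
    """Fetch and unpack the model tarball unless it is already cached."""
    if os.path.exists(dest):
        return
    tmp = dest + ".tar"
    urllib.request.urlretrieve(url, tmp)  # download the tarball
    with tarfile.open(tmp) as tar:
        tar.extractall(dest)              # unpack into the cache directory
    os.remove(tmp)

# usage: download_weights(MODEL_URL, MODEL_CACHE)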
@@ -63,22 +70,33 @@ class Predictor(BasePredictor):
 
     def predict(
         self,
+        task: str = Input(
+            description="Choose a task. For zero-shot text-to-speech, you also need to specify the cut_off_sec of the original audio to be used for zero-shot generation, and the transcript up to the cut_off_sec",
+            choices=[
+                "speech_editing-substitution",
+                "speech_editing-insertion",
+                "speech_editing-deletion",
+                "zero-shot text-to-speech",
+            ],
+            default="speech_editing-substitution",
+        ),
         orig_audio: Path = Input(description="Original audio file"),
         orig_transcript: str = Input(
             description="Transcript of the original audio file. You can use models such as https://replicate.com/openai/whisper and https://replicate.com/vaibhavs10/incredibly-fast-whisper to get the transcript (and modify it if it's not accurate)",
         ),
-        cut_off_sec: float = Input(
-            description="The first seconds of the original audio that are used for zero-shot text-to-speech (TTS). 3 sec of reference is generally enough for high-quality voice cloning, but longer is generally better; try e.g. 3~6 sec",
-            default=3.01,
-        ),
-        orig_transcript_until_cutoff_time: str = Input(
-            description="Transcript of the original audio file up to the cut_off_sec specified above. This step will be improved and automated later",
-        ),
         target_transcript: str = Input(
             description="Transcript of the target audio file",
         ),
+        cut_off_sec: float = Input(
+            description="Valid/Required for the zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech (TTS). 3 sec of reference is generally enough for high-quality voice cloning, but longer is generally better; try e.g. 3~6 sec",
+            default=3.01,
+        ),
+        orig_transcript_until_cutoff_time: str = Input(
+            description="Valid/Required for the zero-shot text-to-speech task. Transcript of the original audio file up to the cut_off_sec specified above. This step will be improved and automated later",
+            default=None,
+        ),
         temperature: float = Input(
             description="Adjusts randomness of outputs; greater than 1 is random and 0 is deterministic",
             ge=0.01,
             le=5,
             default=1,
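With the new task input, cut_off_sec and orig_transcript_until_cutoff_time are only consulted for zero-shot TTS, while the speech-editing tasks derive everything from the two transcripts. A sketch of a zero-shot call through the Replicate Python client, assuming a hypothetical model ref and illustrative transcripts:

import replicate  # pip install replicate

# "owner/voicecraft" is a hypothetical model ref; substitute the actual
# Replicate deployment of this predictor.
output = replicate.run(
    "owner/voicecraft",
    input={
        "task": "zero-shot text-to-speech",
        "orig_audio": open("orig_audio.wav", "rb"),
        "orig_transcript": "But when I had approached so near to them.",
        "target_transcript": "But when I saw the mirage of the lake.",
        # Only consulted for the zero-shot TTS task:
        "cut_off_sec": 3.01,
        "orig_transcript_until_cutoff_time": "But when I had",
    },
)
print(output)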
@@ -89,6 +107,10 @@ class Predictor(BasePredictor):
             le=1.0,
             default=0.8,
         ),
+        stop_repetition: int = Input(
+            default=-1,
+            description="-1 means do not adjust prob of silence tokens. If there are long silences or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4",
+        ),
         sampling_rate: int = Input(
             description="Specify the sampling rate of the audio codec", default=16000
         ),
@@ -97,20 +119,27 @@ class Predictor(BasePredictor):
         ),
     ) -> Path:
         """Run a single prediction on the model"""
+
+        if task == "zero-shot text-to-speech":
+            assert (
+                orig_transcript_until_cutoff_time is not None
+            ), "Please provide orig_transcript_until_cutoff_time for zero-shot text-to-speech task."
         if seed is None:
             seed = int.from_bytes(os.urandom(2), "big")
         print(f"Using seed: {seed}")
 
         seed_everything(seed)
 
-        temp_folder = "exp_temp"
+        temp_folder = "exp_dir"
         if os.path.exists(temp_folder):
             shutil.rmtree(temp_folder)
 
         os.makedirs(temp_folder)
-        os.system(f"cp {str(orig_audio)} {temp_folder}")
-        # filename = os.path.splitext(orig_audio.split("/")[-1])[0]
-        with open(f"{temp_folder}/orig_audio_file.txt", "w") as f:
+        filename = "orig_audio"
+        shutil.copy(orig_audio, f"{temp_folder}/{filename}.wav")
+
+        with open(f"{temp_folder}/{filename}.txt", "w") as f:
             f.write(orig_transcript)
 
         # run MFA to get the alignment
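get_mask_interval (used in the next hunk) reads word timings from align_fn, the CSV produced by the MFA alignment step above. A sketch of what parsing that file might look like, assuming begin/end/label columns after a header row (the usual MFA export shape); load_word_spans is a hypothetical helper:

import csv

def load_word_spans(align_fn: str) -> list[tuple[float, float, str]]:
    """Read word-level spans from an MFA-style alignment CSV.

    Assumes rows of (begin_sec, end_sec, label, ...) under a header,
    which appears to be the shape get_mask_interval consumes.
    """
    spans = []
    with open(align_fn) as f:
        reader = csv.reader(f)
        next(reader)  # skip header row
        for row in reader:
            begin, end, label = float(row[0]), float(row[1]), row[2]
            spans.append((begin, end, label))
    return spans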
@@ -121,26 +150,61 @@ class Predictor(BasePredictor):
             subprocess.run(command, shell=True, check=True)
         except subprocess.CalledProcessError as e:
             print("Error:", e)
+            raise RuntimeError("Error running Alignment")
+
         print("Alignment done!")
 
-        audio_fn = str(orig_audio)  # f"{temp_folder}/{filename}.wav"
+        align_fn = f"{align_temp}/{filename}.csv"
+        audio_fn = f"{temp_folder}/{filename}.wav"
         info = torchaudio.info(audio_fn)
         audio_dur = info.num_frames / info.sample_rate
 
-        assert (
-            cut_off_sec < audio_dur
-        ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
-        prompt_end_frame = int(cut_off_sec * info.sample_rate)
+        # hyperparameters for inference
+        left_margin = 0.08
+        right_margin = 0.08
 
         codec_sr = 50
         top_k = 0
         silence_tokens = [1388, 1898, 131]
-        kvcache = 1  # NOTE if OOM, change this to 0, or try the 330M model
+        kvcache = 1 if task == "zero-shot text-to-speech" else 0
 
-        # NOTE adjust the below three arguments if the generation is not as good
-        stop_repetition = 3  # NOTE if the model generates long silences, reduce stop_repetition to 3, 2 or even 1
         sample_batch_size = 4  # NOTE: if there are long silences or unnaturally stretched words, increase sample_batch_size to 5 or higher. The model will then generate sample_batch_size samples of the same audio and pick the shortest one, so if the generated speech rate is too fast, change it to a smaller number.
 
+        if task == "zero-shot text-to-speech":
+            assert (
+                cut_off_sec < audio_dur
+            ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
+            prompt_end_frame = int(cut_off_sec * info.sample_rate)
+        else:
+            edit_type = task.split("-")[-1]
+            orig_span, new_span = get_span(
+                orig_transcript, target_transcript, edit_type
+            )
+            if orig_span[0] > orig_span[1]:
+                raise RuntimeError(f"example {audio_fn} failed")
+            if orig_span[0] == orig_span[1]:
+                orig_span_save = [orig_span[0]]
+            else:
+                orig_span_save = orig_span
+            if new_span[0] == new_span[1]:
+                new_span_save = [new_span[0]]
+            else:
+                new_span_save = new_span
+            orig_span_save = ",".join([str(item) for item in orig_span_save])
+            new_span_save = ",".join([str(item) for item in new_span_save])
+            start, end = get_mask_interval(align_fn, orig_span_save, edit_type)
+
+            # span in codec frames
+            morphed_span = (
+                max(start - left_margin, 1 / codec_sr),
+                min(end + right_margin, audio_dur),
+            )  # in seconds
+            mask_interval = [
+                [round(morphed_span[0] * codec_sr), round(morphed_span[1] * codec_sr)]
+            ]
+            mask_interval = torch.LongTensor(mask_interval)  # [M,2], M==1 for now
+
         decode_config = {
             "top_k": top_k,
             "top_p": top_p,
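To make the span arithmetic concrete: at codec_sr = 50 codec frames per second, a hypothetical word span of 1.20 to 1.60 s, padded by the 0.08 s margins, maps to codec frames 56 to 84:

# Worked example of the mask-interval math above (values are illustrative).
codec_sr = 50            # codec frames per second
left_margin = right_margin = 0.08
audio_dur = 8.0          # seconds
start, end = 1.20, 1.60  # hypothetical word span from get_mask_interval

morphed_span = (
    max(start - left_margin, 1 / codec_sr),  # 1.12 s (floor: one codec frame)
    min(end + right_margin, audio_dur),      # 1.68 s (ceiling: audio length)
)
mask_interval = [[round(morphed_span[0] * codec_sr),
                  round(morphed_span[1] * codec_sr)]]  # [[56, 84]]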
@@ -150,27 +214,45 @@ class Predictor(BasePredictor):
             "codec_audio_sr": sampling_rate,
             "codec_sr": codec_sr,
             "silence_tokens": silence_tokens,
-            "sample_batch_size": sample_batch_size,
         }
-        concated_audio, gen_audio = inference_one_sample(
-            self.model,
-            self.ckpt["config"],
-            self.phn2num,
-            self.text_tokenizer,
-            self.audio_tokenizer,
-            audio_fn,
-            orig_transcript_until_cutoff_time.strip() + " " + target_transcript.strip(),
-            self.device,
-            decode_config,
-            prompt_end_frame,
-        )
+
+        if task == "zero-shot text-to-speech":
+            decode_config["sample_batch_size"] = sample_batch_size
+
+            concated_audio, gen_audio = inference_one_sample(
+                self.model,
+                self.ckpt["config"],
+                self.phn2num,
+                self.text_tokenizer,
+                self.audio_tokenizer,
+                audio_fn,
+                orig_transcript_until_cutoff_time.strip()
+                + " "
+                + target_transcript.strip(),
+                self.device,
+                decode_config,
+                prompt_end_frame,
+            )
+
+        else:
+            orig_audio, gen_audio = inference_one_sample_editing(
+                self.model,
+                self.ckpt["config"],
+                self.phn2num,
+                self.text_tokenizer,
+                self.audio_tokenizer,
+                audio_fn,
+                target_transcript,
+                mask_interval,
+                self.device,
+                decode_config,
+            )
+
         # save segments for comparison
-        concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
+        gen_audio = gen_audio[0].cpu()
 
         out = "/tmp/out.wav"
         torchaudio.save(out, gen_audio, sampling_rate)
-        torchaudio.save("out.wav", gen_audio, sampling_rate)
         return Path(out)
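seed_everything is called in predict() but defined outside these hunks. A plausible implementation, offered as an assumption rather than the commit's actual code:

import os
import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    """Pin every RNG the model plausibly touches for reproducible sampling."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)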