Compare commits

2 Commits: 9ffb152332 ... 8c1f9a9dc9

| Author | SHA1 | Date |
|---|---|---|
| Chenxi | 8c1f9a9dc9 | |
| chenxwh | 0da8ee4b7a | |
cog.yaml (2 changes)
```diff
@@ -17,6 +17,8 @@ build:
     - phonemizer==3.2.1
     - datasets==2.16.0
     - torchmetrics==0.11.1
+    - whisperx==3.1.1
+    - openai-whisper>=20231117
   run:
     - curl -O https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh
     - bash Miniconda3-py310_23.3.1-0-Linux-x86_64.sh -b -p /cog/miniconda
```
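The build section gains two pins used by the new predict.py below; the openai-whisper package supplies the Whisper model and tokenizer it now imports. As a hedged smoke test (illustrative only, not part of the commit), the following checks that the pinned package exposes exactly the symbols predict.py imports:

```python
# Illustrative smoke test (not part of the commit): confirm the openai-whisper
# package pinned above exposes the pieces predict.py imports below.
from whisper.model import Whisper, ModelDimensions
from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=False)
print("openai-whisper import OK; end-of-transcript token id:", tokenizer.eot)
```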
predict.py (206 changes)
```diff
@@ -2,16 +2,20 @@
 # https://github.com/replicate/cog/blob/main/docs/python.md
 
 import os
 import stat
 import time
-import numpy as np
-import warnings
 import random
 import getpass
-import torch
-import torchaudio
 import shutil
 import subprocess
+import sys
+import warnings
+import torch
+import numpy as np
+import torchaudio
 
+from whisper.model import Whisper, ModelDimensions
+from whisper.tokenizer import get_tokenizer
+from cog import BasePredictor, Input, Path, BaseModel
+
 warnings.filterwarnings("ignore", category=UserWarning)
 os.environ["USER"] = getpass.getuser()
```
```diff
@@ -20,7 +24,6 @@ from data.tokenizer import (
     AudioTokenizer,
     TextTokenizer,
 )
-from cog import BasePredictor, Input, Path
 from models import voicecraft
 from inference_tts_scale import inference_one_sample
 from edit_utils import get_span
```
```diff
@@ -31,11 +34,38 @@ from inference_speech_editing_scale import (
 
 ENV_NAME = "myenv"
 
 
-MODEL_URL = "https://weights.replicate.delivery/default/VoiceCraft.tar"
+MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft.tar"
 MODEL_CACHE = "model_cache"
 
 
+class ModelOutput(BaseModel):
+    whisper_transcript_orig_audio: str
+    generated_audio: Path
+
+
+class WhisperModel:
+    def __init__(self, model_cache, model_name="base.en", device="cuda"):
+
+        with open(f"{model_cache}/{model_name}.pt", "rb") as fp:
+            checkpoint = torch.load(fp, map_location="cpu")
+        dims = ModelDimensions(**checkpoint["dims"])
+        self.model = Whisper(dims)
+        self.model.load_state_dict(checkpoint["model_state_dict"])
+        self.model.to(device)
+
+        tokenizer = get_tokenizer(multilingual=False)
+        self.supress_tokens = [-1] + [
+            i
+            for i in range(tokenizer.eot)
+            if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
+        ]
+
+    def transcribe(self, audio_path):
+        return self.model.transcribe(
+            audio_path, suppress_tokens=self.supress_tokens, word_timestamps=True
+        )["segments"]
+
+
 def download_weights(url, dest):
     start = time.time()
     print("downloading url: ", url)
```
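In WhisperModel above, the suppress list collects every vocabulary id that decodes to bare digits, nudging Whisper to spell numbers out as words; that this suits the espeak-based TextTokenizer downstream is an inference, not stated in the commit. A minimal sketch of the same rule, assuming openai-whisper is installed:

```python
from whisper.tokenizer import get_tokenizer

# Same digit-filter rule as WhisperModel above: collect ids whose decoded text
# is nothing but digits (after stripping a leading space), then suppress them
# at transcription time so numbers come out as words.
tokenizer = get_tokenizer(multilingual=False)
digit_only_ids = [
    i
    for i in range(tokenizer.eot)
    if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
]
print(f"{len(digit_only_ids)} digit-only tokens would be suppressed")
```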
```diff
@@ -49,56 +79,87 @@ class Predictor(BasePredictor):
         """Load the model into memory to make running multiple predictions efficient"""
         self.device = "cuda"
 
-        voicecraft_name = "giga830M.pth"  # or giga330M.pth
-
         if not os.path.exists(MODEL_CACHE):
             download_weights(MODEL_URL, MODEL_CACHE)
 
         encodec_fn = f"{MODEL_CACHE}/encodec_4cb2048_giga.th"
-        ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}"
+        self.models, self.ckpt, self.phn2num = {}, {}, {}
+        for voicecraft_name in [
+            "giga830M.pth",
+            "giga330M.pth",
+            "gigaHalfLibri330M_TTSEnhanced_max16s.pth",
+        ]:
+            ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}"
 
-        self.ckpt = torch.load(ckpt_fn, map_location="cpu")
-        self.model = voicecraft.VoiceCraft(self.ckpt["config"])
-        self.model.load_state_dict(self.ckpt["model"])
-        self.model.to(self.device)
-        self.model.eval()
+            self.ckpt[voicecraft_name] = torch.load(ckpt_fn, map_location="cpu")
+            self.models[voicecraft_name] = voicecraft.VoiceCraft(
+                self.ckpt[voicecraft_name]["config"]
+            )
+            self.models[voicecraft_name].load_state_dict(
+                self.ckpt[voicecraft_name]["model"]
+            )
+            self.models[voicecraft_name].to(self.device)
+            self.models[voicecraft_name].eval()
 
-        self.phn2num = self.ckpt["phn2num"]
+            self.phn2num[voicecraft_name] = self.ckpt[voicecraft_name]["phn2num"]
 
         self.text_tokenizer = TextTokenizer(backend="espeak")
         self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)
+        self.transcribe_models = {
+            k: WhisperModel(MODEL_CACHE, k, self.device)
+            for k in ["base.en", "small.en", "medium.en"]
+        }
 
     def predict(
         self,
         task: str = Input(
-            description="Choose a task. For zero-shot text-to-speech, you also need to specify the cut_off_sec of the original audio to be used for zero-shot generation and the transcript until the cut_off_sec",
+            description="Choose a task",
             choices=[
                 "speech_editing-substitution",
                 "speech_editing-insertion",
                 "speech_editing-deletion",
                 "zero-shot text-to-speech",
             ],
-            default="speech_editing-substitution",
+            default="zero-shot text-to-speech",
         ),
+        voicecraft_model: str = Input(
+            description="Choose a model",
+            choices=["giga830M.pth", "giga330M.pth", "giga330M_TTSEnhanced.pth"],
+            default="giga330M_TTSEnhanced.pth",
+        ),
-        orig_audio: Path = Input(
-            description="Original audio file. WhisperX small.en model will be used for transcription"
-        ),
+        orig_audio: Path = Input(description="Original audio file"),
         orig_transcript: str = Input(
-            description="Transcript of the original audio file. You can use models such as https://replicate.com/openai/whisper and https://replicate.com/vaibhavs10/incredibly-fast-whisper to get the transcript (and modify it if it's not accurate)",
+            description="Optionally provide the transcript of the input audio. Leave it blank to use the whisper model below to generate the transcript. Inaccurate transcription may lead to error TTS or speech editing",
+            default="",
         ),
+        whisper_model: str = Input(
+            description="If orig_transcript is not provided above, choose a Whisper model. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ",
+            choices=["base.en", "small.en", "medium.en"],
+            default="base.en",
+        ),
         target_transcript: str = Input(
             description="Transcript of the target audio file",
         ),
         cut_off_sec: float = Input(
-            description="Valid/Required for zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech (TTS). 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec",
-            default=None,
+            description="Only used for for zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech. 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec",
+            default=3.01,
         ),
-        orig_transcript_until_cutoff_time: str = Input(
-            description="Valid/Required for zero-shot text-to-speech task. Transcript of the original audio file until the cut_off_sec specified above. This process will be improved and made automatically later",
-            default=None,
-        ),
         kvcache: int = Input(
             description="Set to 0 to use less VRAM, but with slower inference",
             default=1,
         ),
         left_margin: float = Input(
             description="Margin to the left of the editing segment",
             default=0.08,
         ),
         right_margin: float = Input(
             description="Margin to the right of the editing segment",
             default=0.08,
         ),
         temperature: float = Input(
-            description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic",
-            ge=0.01,
-            le=5,
+            description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic. Do not recommend to change",
             default=1,
         ),
         top_p: float = Input(
```
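setup() now eagerly loads all three checkpoints and keys every artifact (model, checkpoint dict, phn2num map) by checkpoint name, trading startup time and memory for zero-cost switching at request time. A schematic sketch of that registry pattern, with hypothetical stand-ins rather than the commit's code:

```python
# Registry pattern sketch (hypothetical stand-ins, not the commit's code):
# load each checkpoint once at startup, key everything by name, then select
# per request with no reload cost.
checkpoint_names = [
    "giga830M.pth",
    "giga330M.pth",
    "gigaHalfLibri330M_TTSEnhanced_max16s.pth",
]

models, ckpts, phn2nums = {}, {}, {}
for name in checkpoint_names:
    ckpts[name] = {"config": f"cfg({name})", "phn2num": {}}  # stand-in for torch.load
    models[name] = f"model({name})"                          # stand-in for VoiceCraft(...)
    phn2nums[name] = ckpts[name]["phn2num"]

requested = "giga830M.pth"  # the voicecraft_model input selects at predict time
print(models[requested], ckpts[requested]["config"])
```

Note that the exposed choice giga330M_TTSEnhanced.pth is remapped to the checkpoint file gigaHalfLibri330M_TTSEnhanced_max16s.pth inside predict(), as a later hunk shows.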
```diff
@@ -109,28 +170,33 @@ class Predictor(BasePredictor):
         ),
         stop_repetition: int = Input(
             default=-1,
-            description=" -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally strecthed words, increase sample_batch_size to 2, 3 or even 4",
+            description=" -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4",
         ),
-        sampling_rate: int = Input(
-            description="Specify the sampling rate of the audio codec", default=16000
+        sample_batch_size: int = Input(
+            description="The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one",
+            default=4,
         ),
         seed: int = Input(
             description="Random seed. Leave blank to randomize the seed", default=None
         ),
-    ) -> Path:
+    ) -> ModelOutput:
         """Run a single prediction on the model"""
 
-        if task == "zero-shot text-to-speech":
-            assert (
-                orig_transcript_until_cutoff_time is not None
-                and cut_off_sec is not None
-            ), "Please provide cut_off_sec and orig_transcript_until_cutoff_time for zero-shot text-to-speech task."
         if seed is None:
             seed = int.from_bytes(os.urandom(2), "big")
         print(f"Using seed: {seed}")
 
         seed_everything(seed)
 
+        segments = self.transcribe_models[whisper_model].transcribe(str(orig_audio))
+        state = get_transcribe_state(segments)
+        whisper_transcript = state["transcript"].strip()
+
+        if len(orig_transcript.strip()) == 0:
+            orig_transcript = whisper_transcript
+
+        print(f"The transcript from the Whisper model: {whisper_transcript}")
+
         temp_folder = "exp_dir"
         if os.path.exists(temp_folder):
             shutil.rmtree(temp_folder)
```
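The new flow always transcribes with the selected Whisper model and falls back to that transcript only when orig_transcript was left blank. Restated as a tiny self-contained helper (resolve_transcript is hypothetical, for illustration):

```python
def resolve_transcript(orig_transcript: str, whisper_transcript: str) -> str:
    # Empty or whitespace-only user input means "trust the Whisper transcript".
    if len(orig_transcript.strip()) == 0:
        return whisper_transcript
    return orig_transcript

assert resolve_transcript("", "hello world") == "hello world"
assert resolve_transcript("  ", "hello world") == "hello world"
assert resolve_transcript("Hello, world!", "hello world") == "Hello, world!"
```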
```diff
@@ -161,14 +227,13 @@ class Predictor(BasePredictor):
         audio_dur = info.num_frames / info.sample_rate
 
         # hyperparameters for inference
-        left_margin = 0.08
-        right_margin = 0.08
         codec_audio_sr = 16000
         codec_sr = 50
         top_k = 0
         silence_tokens = [1388, 1898, 131]
-        kvcache = 1 if task == "zero-shot text-to-speech" else 0
 
-        sample_batch_size = 4  # NOTE: if the if there are long silence or unnaturally strecthed words, increase sample_batch_size to 5 or higher. What this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. So if the speech rate of the generated is too fast change it to a smaller number.
+        if voicecraft_model == "giga330M_TTSEnhanced.pth":
+            voicecraft_model = "gigaHalfLibri330M_TTSEnhanced_max16s.pth"
 
         if task == "zero-shot text-to-speech":
             assert (
```
```diff
@@ -176,8 +241,11 @@ class Predictor(BasePredictor):
             ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
             prompt_end_frame = int(cut_off_sec * info.sample_rate)
 
+            idx = find_closest_cut_off_word(state["word_bounds"], cut_off_sec)
+            orig_transcript_until_cutoff_time = "".join(
+                [word_bound["word"] for word_bound in state["word_bounds"][:idx]]
+            )
         else:
-
             edit_type = task.split("-")[-1]
             orig_span, new_span = get_span(
                 orig_transcript, target_transcript, edit_type
```
```diff
@@ -212,18 +280,17 @@ class Predictor(BasePredictor):
             "temperature": temperature,
             "stop_repetition": stop_repetition,
             "kvcache": kvcache,
-            "codec_audio_sr": sampling_rate,
+            "codec_audio_sr": codec_audio_sr,
             "codec_sr": codec_sr,
             "silence_tokens": silence_tokens,
         }
 
         if task == "zero-shot text-to-speech":
             decode_config["sample_batch_size"] = sample_batch_size
-
-            concated_audio, gen_audio = inference_one_sample(
-                self.model,
-                self.ckpt["config"],
-                self.phn2num,
+            _, gen_audio = inference_one_sample(
+                self.models[voicecraft_model],
+                self.ckpt[voicecraft_model]["config"],
+                self.phn2num[voicecraft_model],
                 self.text_tokenizer,
                 self.audio_tokenizer,
                 audio_fn,
```
```diff
@@ -234,12 +301,11 @@ class Predictor(BasePredictor):
                 decode_config,
                 prompt_end_frame,
             )
-
         else:
-            orig_audio, gen_audio = inference_one_sample_editing(
-                self.model,
-                self.ckpt["config"],
-                self.phn2num,
+            _, gen_audio = inference_one_sample_editing(
+                self.models[voicecraft_model],
+                self.ckpt[voicecraft_model]["config"],
+                self.phn2num[voicecraft_model],
                 self.text_tokenizer,
                 self.audio_tokenizer,
                 audio_fn,
```
```diff
@@ -253,8 +319,10 @@ class Predictor(BasePredictor):
         gen_audio = gen_audio[0].cpu()
 
         out = "/tmp/out.wav"
-        torchaudio.save(out, gen_audio, sampling_rate)
-        return Path(out)
+        torchaudio.save(out, gen_audio, codec_audio_sr)
+        return ModelOutput(
+            generated_audio=Path(out), whisper_transcript_orig_audio=whisper_transcript
+        )
 
 
 def seed_everything(seed):
```
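The save call now uses the fixed codec_audio_sr (16000, set in the hyperparameters hunk above) instead of the removed sampling_rate input, matching the rate the pipeline's codec is configured to produce. A self-contained sketch of the corrected call, with a silent tensor as a hypothetical stand-in for gen_audio:

```python
import torch
import torchaudio

codec_audio_sr = 16000                      # fixed codec rate from the hunk above
gen_audio = torch.zeros(1, codec_audio_sr)  # stand-in: 1 s of silence, shape (channels, frames)
torchaudio.save("/tmp/out.wav", gen_audio, codec_audio_sr)
```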
```diff
@@ -265,3 +333,29 @@ def seed_everything(seed):
     torch.cuda.manual_seed(seed)
     torch.backends.cudnn.benchmark = False
     torch.backends.cudnn.deterministic = True
+
+
+def get_transcribe_state(segments):
+    words_info = [word_info for segment in segments for word_info in segment["words"]]
+    return {
+        "transcript": " ".join([segment["text"].strip() for segment in segments]),
+        "word_bounds": [
+            {"word": word["word"], "start": word["start"], "end": word["end"]}
+            for word in words_info
+        ],
+    }
+
+
+def find_closest_cut_off_word(word_bounds, cut_off_sec):
+    min_distance = float("inf")
+
+    for i, word_bound in enumerate(word_bounds):
+        distance = abs(word_bound["start"] - cut_off_sec)
+
+        if distance < min_distance:
+            min_distance = distance
+
+        if word_bound["end"] > cut_off_sec:
+            break
+
+    return i
```
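A quick sanity check of find_closest_cut_off_word on made-up word bounds (hypothetical timings; assumes the function above is in scope). It returns the index of the first word whose end crosses the cut-off, so the words before that index form the derived prompt transcript; the min_distance bookkeeping, as committed, never influences the returned value:

```python
word_bounds = [
    {"word": "hello ", "start": 0.0, "end": 0.4},
    {"word": "world ", "start": 0.5, "end": 0.9},
    {"word": "again ", "start": 1.0, "end": 1.5},
]

idx = find_closest_cut_off_word(word_bounds, cut_off_sec=1.1)
prompt = "".join(wb["word"] for wb in word_bounds[:idx])
print(idx, repr(prompt))  # -> 2 'hello world ': words fully before the cut-off
```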