Compare commits

16 commits: 23d82fae81 ... ad18fe1a5c

| Author | SHA1 | Date |
|---|---|---|
| Chenxi | ad18fe1a5c | |
| jason-on-salt-a40 | b10a245b44 | |
| chenxwh | 2a2ee984b6 | |
| chenxwh | 729d0ec69e | |
| chenxwh | ef3dd8285b | |
| chenxwh | 9746a1f60c | |
| Chenxi | 4bd7b83b57 | |
| Chenxi | 6e5382584c | |
| chenxwh | 0da8ee4b7a | |
| Chenxi | e3fc926ca4 | |
| chenxwh | 0c6942fd2a | |
| chenxwh | f649f9216b | |
| Chenxi | 1e2f8391a7 | |
| chenxwh | b8eca5a2d4 | |
| chenxwh | 023d4b1c6c | |
| chenxwh | 49a648fa54 | |
Binary file not shown.

@@ -0,0 +1,8 @@
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.0.1
torchaudio==2.0.2
xformers==0.0.22
tensorboard==2.16.2
phonemizer==3.2.1
datasets==2.16.0
torchmetrics==0.11.1
@@ -0,0 +1,17 @@
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file

# Exclude Git files
.git
.github
.gitignore

# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache

# Exclude Python virtual environment
/venv
@@ -27,4 +27,6 @@ thumbs.db
src/audiocraft

!/demo/
!/demo/*
!/demo/*

.cog/tmp/*
@@ -1,5 +1,6 @@
# VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild
[![Paper](https://img.shields.io/badge/arXiv-2301.12503-brightgreen.svg?style=flat-square)](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf) [![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://jasonppy.github.io/VoiceCraft_web/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IOjpglQyMTO2C3Y94LD9FY0Ocn-RJRg6?usp=sharing)
[![Paper](https://img.shields.io/badge/arXiv-2301.12503-brightgreen.svg?style=flat-square)](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf) [![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://jasonppy.github.io/VoiceCraft_web/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IOjpglQyMTO2C3Y94LD9FY0Ocn-RJRg6?usp=sharing) [![Replicate](https://replicate.com/cjwbw/voicecraft/badge)](https://replicate.com/cjwbw/voicecraft)

### TL;DR
VoiceCraft is a token-infilling neural codec language model that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data including audiobooks, internet videos, and podcasts.
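The new Replicate badge above links to the hosted cjwbw/voicecraft model backed by the predict.py added later in this compare view. A minimal sketch of calling that hosted model from Python with the replicate client; the version hash, audio path, and target transcript below are placeholders, not values taken from this diff:

```python
# Minimal sketch: run the hosted VoiceCraft predictor on Replicate.
# Assumes `pip install replicate` and REPLICATE_API_TOKEN set in the environment.
import replicate

output = replicate.run(
    "cjwbw/voicecraft:<version-hash>",  # placeholder; copy the current version from replicate.com/cjwbw/voicecraft
    input={
        "task": "zero-shot text-to-speech",
        "orig_audio": open("demo/84_121550_000074_000000.wav", "rb"),  # reference clip to clone
        "target_transcript": "Placeholder sentence to synthesize in the reference speaker's voice.",
        "cut_off_sec": 3.01,
    },
)
# The predictor returns the whisper transcript of the input plus the generated audio.
print(output)
```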
@@ -0,0 +1 @@
Subproject commit 69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85
@@ -0,0 +1,25 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  gpu: true
  system_packages:
    - libgl1-mesa-glx
    - libglib2.0-0
    - ffmpeg
    - espeak-ng
  python_version: "3.11"
  python_packages:
    - torch==2.1.0
    - torchaudio==2.1.0
    - xformers
    - phonemizer==3.2.1
    - whisperx==3.1.1
    - openai-whisper>=20231117
  run:
    # - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft
    - pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft # use "git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft" instead if it hits an audiocraft import error
    - pip install "pydantic<2.0.0"
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth"
predict: "predict.py:Predictor"
@@ -77,8 +77,14 @@ class WhisperxModel:
def load_models(whisper_backend_name, whisper_model_name, alignment_model_name, voicecraft_model_name):
    global transcribe_model, align_model, voicecraft_model

    if voicecraft_model_name == "giga330M_TTSEnhanced":
    if voicecraft_model_name == "330M":
        voicecraft_model_name = "giga330M"
    elif voicecraft_model_name == "830M":
        voicecraft_model_name = "giga830M"
    elif voicecraft_model_name == "330M_TTSEnhanced":
        voicecraft_model_name = "gigaHalfLibri330M_TTSEnhanced_max16s"
    elif voicecraft_model_name == "830M_TTSEnhanced":
        voicecraft_model_name = "830M_TTSEnhanced"

    if alignment_model_name is not None:
        align_model = WhisperxAlignModel()
@@ -434,19 +440,19 @@ def get_app():
    with gr.Column(scale=5):
        with gr.Accordion("Select models", open=False) as models_selector:
            with gr.Row():
                voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M",
                                                   choices=["giga330M", "giga830M", "giga330M_TTSEnhanced"])
                whisper_backend_choice = gr.Radio(label="Whisper backend", value="whisperX", choices=["whisper", "whisperX"])
                voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="830M_TTSEnhanced",
                                                   choices=["330M", "830M", "330M_TTSEnhanced", "830M_TTSEnhanced"])
                whisper_backend_choice = gr.Radio(label="Whisper backend", value="whisperX", choices=["whisperX", "whisper"])
                whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
                                                choices=[None, "base.en", "small.en", "medium.en", "large"])
                align_model_choice = gr.Radio(label="Forced alignment model", value="whisperX", choices=[None, "whisperX"])
                align_model_choice = gr.Radio(label="Forced alignment model", value="whisperX", choices=["whisperX", None])

    with gr.Row():
        with gr.Column(scale=2):
            input_audio = gr.Audio(value=f"{DEMO_PATH}/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True)
            with gr.Group():
                original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript,
                                                 info="Use whisper model to get the transcript. Fix and align it if necessary.")
                                                 info="Use whisperx model to get the transcript. Fix and align it if necessary.")
                with gr.Accordion("Word start time", open=False):
                    transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word")
                with gr.Accordion("Word end time", open=False):
@@ -499,7 +505,7 @@ def get_app():
        with gr.Accordion("Generation Parameters - change these if you are unhappy with the generation", open=False):
            stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3, 4], value=3,
                                       info="if there are long silences in the generated audio, reduce the stop_repetition to 2 or 1. -1 = disabled")
            sample_batch_size = gr.Number(label="speech rate", value=4, precision=0,
            sample_batch_size = gr.Number(label="speech rate", value=3, precision=0,
                                          info="The higher the number, the faster the output will be. "
                                               "Under the hood, the model will generate this many samples and choose the shortest one. "
                                               "For giga330M_TTSEnhanced, 1 or 2 should be fine since the model is trained to do TTS.")
@@ -0,0 +1,432 @@
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import time
import random
import getpass
import shutil
import subprocess
import torch
import numpy as np
import torchaudio
from whisper.model import Whisper, ModelDimensions
from whisper.tokenizer import get_tokenizer
from cog import BasePredictor, Input, Path, BaseModel

os.environ["USER"] = getpass.getuser()

from data.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
)
from models import voicecraft
from inference_tts_scale import inference_one_sample
from edit_utils import get_span
from inference_speech_editing_scale import (
    inference_one_sample as inference_one_sample_editing,
)


MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft-models.tar"  # all the models are cached and uploaded to replicate.delivery for faster booting
MODEL_CACHE = "model_cache"


class ModelOutput(BaseModel):
    whisper_transcript_orig_audio: str
    generated_audio: Path


class WhisperxAlignModel:
    def __init__(self):
        from whisperx import load_align_model

        self.model, self.metadata = load_align_model(
            language_code="en", device="cuda:0"
        )

    def align(self, segments, audio_path):
        from whisperx import align, load_audio

        audio = load_audio(audio_path)
        return align(
            segments,
            self.model,
            self.metadata,
            audio,
            device="cuda:0",
            return_char_alignments=False,
        )["segments"]


class WhisperxModel:
    def __init__(self, model_name, align_model: WhisperxAlignModel, device="cuda"):
        from whisperx import load_model

        # the model weights are cached from Systran/faster-whisper-base.en etc
        self.model = load_model(
            model_name,
            device,
            asr_options={
                "suppress_numerals": True,
                "max_new_tokens": None,
                "clip_timestamps": None,
                "hallucination_silence_threshold": None,
            },
        )
        self.align_model = align_model

    def transcribe(self, audio_path):
        segments = self.model.transcribe(audio_path, language="en", batch_size=8)[
            "segments"
        ]
        return self.align_model.align(segments, audio_path)


class WhisperModel:
    def __init__(self, model_cache, model_name="base.en", device="cuda"):
        # the model weights are cached from https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/__init__.py#L17
        with open(f"{model_cache}/{model_name}.pt", "rb") as fp:
            checkpoint = torch.load(fp, map_location="cpu")
            dims = ModelDimensions(**checkpoint["dims"])
            self.model = Whisper(dims)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.model.to(device)

        tokenizer = get_tokenizer(multilingual=False)
        self.suppress_tokens = [-1] + [
            i
            for i in range(tokenizer.eot)
            if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
        ]

    def transcribe(self, audio_path):
        return self.model.transcribe(
            audio_path, suppress_tokens=self.suppress_tokens, word_timestamps=True
        )["segments"]


def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.device = "cuda"

        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        encodec_fn = f"{MODEL_CACHE}/encodec_4cb2048_giga.th"
        self.models, self.ckpt, self.phn2num = {}, {}, {}
        for voicecraft_name in [
            "giga830M.pth",
            "giga330M.pth",
            "gigaHalfLibri330M_TTSEnhanced_max16s.pth",
        ]:
            ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}"

            self.ckpt[voicecraft_name] = torch.load(ckpt_fn, map_location="cpu")
            self.models[voicecraft_name] = voicecraft.VoiceCraft(
                self.ckpt[voicecraft_name]["config"]
            )
            self.models[voicecraft_name].load_state_dict(
                self.ckpt[voicecraft_name]["model"]
            )
            self.models[voicecraft_name].to(self.device)
            self.models[voicecraft_name].eval()

            self.phn2num[voicecraft_name] = self.ckpt[voicecraft_name]["phn2num"]

        self.text_tokenizer = TextTokenizer(backend="espeak")
        self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)
        self.transcribe_models_whisper = {
            k: WhisperModel(MODEL_CACHE, k, self.device)
            for k in ["base.en", "small.en", "medium.en"]
        }

        align_model = WhisperxAlignModel()
        self.transcribe_models_whisperx = {
            k: WhisperxModel(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}", align_model)
            for k in ["base.en", "small.en", "medium.en"]
        }

    def predict(
        self,
        task: str = Input(
            description="Choose a task",
            choices=[
                "speech_editing-substitution",
                "speech_editing-insertion",
                "speech_editing-deletion",
                "zero-shot text-to-speech",
            ],
            default="zero-shot text-to-speech",
        ),
        voicecraft_model: str = Input(
            description="Choose a model",
            choices=["giga830M.pth", "giga330M.pth", "giga330M_TTSEnhanced.pth"],
            default="giga330M_TTSEnhanced.pth",
        ),
        orig_audio: Path = Input(
            description="Original audio file. WhisperX small.en model will be used for transcription"
        ),
        orig_transcript: str = Input(
            description="Optionally provide the transcript of the input audio. Leave it blank to use the whisper model below to generate the transcript. Inaccurate transcription may lead to erroneous TTS or speech editing",
            default="",
        ),
        whisper_model: str = Input(
            description="If orig_transcript is not provided above, choose a Whisper or WhisperX model. The WhisperX models include an extra alignment step. Inaccurate transcription may lead to erroneous TTS or speech editing. You can modify the generated transcript and provide it directly as orig_transcript above",
            choices=[
                "whisper-base.en",
                "whisper-small.en",
                "whisper-medium.en",
                "whisperx-base.en",
                "whisperx-small.en",
                "whisperx-medium.en",
            ],
            default="whisper-base.en",
        ),
        target_transcript: str = Input(
            description="Transcript of the target audio file",
        ),
        cut_off_sec: float = Input(
            description="Only used for the zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech. 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec",
            default=3.01,
        ),
        kvcache: int = Input(
            description="Set to 0 to use less VRAM, but with slower inference",
            default=1,
        ),
        left_margin: float = Input(
            description="Margin to the left of the editing segment",
            default=0.08,
        ),
        right_margin: float = Input(
            description="Margin to the right of the editing segment",
            default=0.08,
        ),
        temperature: float = Input(
            description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic. Changing it is not recommended",
            default=1,
        ),
        top_p: float = Input(
            description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens",
            ge=0.0,
            le=1.0,
            default=0.8,
        ),
        stop_repetition: int = Input(
            default=-1,
            description="-1 means do not adjust prob of silence tokens. If there are long silences or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4",
        ),
        sample_batch_size: int = Input(
            description="The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one",
            default=4,
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> ModelOutput:
        """Run a single prediction on the model"""

        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        seed_everything(seed)

        whisper_model, whisper_model_size = whisper_model.split("-")

        if whisper_model == "whisper":
            segments = self.transcribe_models_whisper[whisper_model_size].transcribe(
                str(orig_audio)
            )
        else:
            segments = self.transcribe_models_whisperx[whisper_model_size].transcribe(
                str(orig_audio)
            )

        state = get_transcribe_state(segments)

        whisper_transcript = state["transcript"].strip()

        if len(orig_transcript.strip()) == 0:
            orig_transcript = whisper_transcript

        print(f"The transcript from the Whisper model: {whisper_transcript}")

        temp_folder = "exp_dir"
        if os.path.exists(temp_folder):
            shutil.rmtree(temp_folder)

        os.makedirs(temp_folder)

        filename = "orig_audio"
        audio_fn = str(orig_audio)

        info = torchaudio.info(audio_fn)
        audio_dur = info.num_frames / info.sample_rate

        # hyperparameters for inference
        codec_audio_sr = 16000
        codec_sr = 50
        top_k = 0
        silence_tokens = [1388, 1898, 131]

        if voicecraft_model == "giga330M_TTSEnhanced.pth":
            voicecraft_model = "gigaHalfLibri330M_TTSEnhanced_max16s.pth"

        if task == "zero-shot text-to-speech":
            assert (
                cut_off_sec < audio_dur
            ), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
            prompt_end_frame = int(cut_off_sec * info.sample_rate)

            idx = find_closest_cut_off_word(state["word_bounds"], cut_off_sec)
            orig_transcript_until_cutoff_time = "".join(
                [word_bound["word"] for word_bound in state["word_bounds"][:idx]]
            )
        else:
            edit_type = task.split("-")[-1]
            orig_span, new_span = get_span(
                orig_transcript, target_transcript, edit_type
            )
            if orig_span[0] > orig_span[1]:
                raise RuntimeError(f"example {audio_fn} failed")
            if orig_span[0] == orig_span[1]:
                orig_span_save = [orig_span[0]]
            else:
                orig_span_save = orig_span
            if new_span[0] == new_span[1]:
                new_span_save = [new_span[0]]
            else:
                new_span_save = new_span
            orig_span_save = ",".join([str(item) for item in orig_span_save])
            new_span_save = ",".join([str(item) for item in new_span_save])

            start, end = get_mask_interval_from_word_bounds(
                state["word_bounds"], orig_span_save, edit_type
            )

            # span in codec frames
            morphed_span = (
                max(start - left_margin, 1 / codec_sr),
                min(end + right_margin, audio_dur),
            )  # in seconds
            mask_interval = [
                [round(morphed_span[0] * codec_sr), round(morphed_span[1] * codec_sr)]
            ]
            mask_interval = torch.LongTensor(mask_interval)  # [M,2], M==1 for now

        decode_config = {
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
            "stop_repetition": stop_repetition,
            "kvcache": kvcache,
            "codec_audio_sr": codec_audio_sr,
            "codec_sr": codec_sr,
            "silence_tokens": silence_tokens,
        }

        if task == "zero-shot text-to-speech":
            decode_config["sample_batch_size"] = sample_batch_size
            _, gen_audio = inference_one_sample(
                self.models[voicecraft_model],
                self.ckpt[voicecraft_model]["config"],
                self.phn2num[voicecraft_model],
                self.text_tokenizer,
                self.audio_tokenizer,
                audio_fn,
                orig_transcript_until_cutoff_time.strip()
                + " "
                + target_transcript.strip(),
                self.device,
                decode_config,
                prompt_end_frame,
            )
        else:
            _, gen_audio = inference_one_sample_editing(
                self.models[voicecraft_model],
                self.ckpt[voicecraft_model]["config"],
                self.phn2num[voicecraft_model],
                self.text_tokenizer,
                self.audio_tokenizer,
                audio_fn,
                target_transcript,
                mask_interval,
                self.device,
                decode_config,
            )

        # save segments for comparison
        gen_audio = gen_audio[0].cpu()

        out = "/tmp/out.wav"
        torchaudio.save(out, gen_audio, codec_audio_sr)
        return ModelOutput(
            generated_audio=Path(out), whisper_transcript_orig_audio=whisper_transcript
        )


def seed_everything(seed):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_transcribe_state(segments):
    words_info = [word_info for segment in segments for word_info in segment["words"]]
    return {
        "transcript": " ".join([segment["text"].strip() for segment in segments]),
        "word_bounds": [
            {"word": word["word"], "start": word["start"], "end": word["end"]}
            for word in words_info
        ],
    }


def find_closest_cut_off_word(word_bounds, cut_off_sec):
    min_distance = float("inf")

    for i, word_bound in enumerate(word_bounds):
        distance = abs(word_bound["start"] - cut_off_sec)

        if distance < min_distance:
            min_distance = distance

        if word_bound["end"] > cut_off_sec:
            break

    return i


def get_mask_interval_from_word_bounds(word_bounds, word_span_ind, editType):
    tmp = word_span_ind.split(",")
    s, e = int(tmp[0]), int(tmp[-1])
    start = None
    for j, item in enumerate(word_bounds):
        if j == s:
            if editType == "insertion":
                start = float(item["end"])
            else:
                start = float(item["start"])
        if j == e:
            if editType == "insertion":
                end = float(item["start"])
            else:
                end = float(item["end"])
            assert start is not None
            break
    return (start, end)
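Since the predictor above is plain Python, it can also be exercised directly for a quick smoke test outside of Cog. A minimal sketch, assuming a CUDA GPU, the VoiceCraft repo on PYTHONPATH, and that the weights referenced by MODEL_URL can be downloaded into model_cache; the audio path and transcripts are placeholders rather than values from this diff, and every argument is passed explicitly because the Input() defaults are only resolved when Cog itself serves the model:

```python
# Hypothetical local smoke test for the Predictor defined in predict.py.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # downloads/loads the three VoiceCraft checkpoints plus whisper/whisperx models

output = predictor.predict(
    task="zero-shot text-to-speech",
    voicecraft_model="giga330M_TTSEnhanced.pth",
    orig_audio="demo/84_121550_000074_000000.wav",  # placeholder reference clip
    orig_transcript="",                             # empty -> fall back to the whisper transcript
    whisper_model="whisper-base.en",
    target_transcript="Placeholder target transcript to synthesize.",
    cut_off_sec=3.01,
    kvcache=1,
    left_margin=0.08,
    right_margin=0.08,
    temperature=1,
    top_p=0.8,
    stop_repetition=-1,
    sample_batch_size=4,
    seed=None,
)
print(output.whisper_transcript_orig_audio)
print(output.generated_audio)  # path to the generated /tmp/out.wav
```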