chenxwh 2024-04-21 22:30:56 +00:00
parent 2a2ee984b6
commit 87f4fa5d21
3 changed files with 23 additions and 68 deletions

@@ -1 +0,0 @@
-Subproject commit 69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85

cog.yaml

@@ -17,8 +17,7 @@ build:
     - whisperx==3.1.1
     - openai-whisper>=20231117
   run:
-    # - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft
-    - pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft # use "git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft" instead if hits audiocraft import error
+    - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft
     - pip install "pydantic<2.0.0"
     - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
     - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth"
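Note that the removed install was pinned to audiocraft commit c5157b5b, while the new `git clone` tracks the default branch (the deleted subproject entry above points the same way). If reproducibility matters, a build step can keep the pin with the clone-based install; a minimal sketch, assuming only git and pip on PATH (the helper name is invented):

import subprocess

AUDIOCRAFT_COMMIT = "c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0"  # pin from the removed line

def install_audiocraft_pinned():
    # clone, rewind to the known-good commit, then editable-install
    subprocess.run(["git", "clone", "https://github.com/facebookresearch/audiocraft"], check=True)
    subprocess.run(["git", "checkout", AUDIOCRAFT_COMMIT], cwd="audiocraft", check=True)
    subprocess.run(["pip", "install", "-e", "./audiocraft"], check=True)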

predict.py

@@ -10,8 +10,6 @@ import subprocess
 import torch
 import numpy as np
 import torchaudio
-from whisper.model import Whisper, ModelDimensions
-from whisper.tokenizer import get_tokenizer
 from cog import BasePredictor, Input, Path, BaseModel
 
 os.environ["USER"] = getpass.getuser()
@@ -83,30 +81,6 @@ class WhisperxModel:
         return self.align_model.align(segments, audio_path)
 
 
-class WhisperModel:
-    def __init__(self, model_cache, model_name="base.en", device="cuda"):
-        # the model weights are cached from https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/__init__.py#L17
-        with open(f"{model_cache}/{model_name}.pt", "rb") as fp:
-            checkpoint = torch.load(fp, map_location="cpu")
-        dims = ModelDimensions(**checkpoint["dims"])
-        self.model = Whisper(dims)
-        self.model.load_state_dict(checkpoint["model_state_dict"])
-        self.model.to(device)
-
-        tokenizer = get_tokenizer(multilingual=False)
-        self.supress_tokens = [-1] + [
-            i
-            for i in range(tokenizer.eot)
-            if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
-        ]
-
-    def transcribe(self, audio_path):
-        return self.model.transcribe(
-            audio_path, suppress_tokens=self.supress_tokens, word_timestamps=True
-        )["segments"]
-
-
 def download_weights(url, dest):
     start = time.time()
     print("downloading url: ", url)
@@ -146,13 +120,9 @@ class Predictor(BasePredictor):
         self.text_tokenizer = TextTokenizer(backend="espeak")
         self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)
 
-        self.transcribe_models_whisper = {
-            k: WhisperModel(MODEL_CACHE, k, self.device)
-            for k in ["base.en", "small.en", "medium.en"]
-        }
         align_model = WhisperxAlignModel()
-        self.transcribe_models_whisperx = {
+        self.transcribe_models = {
             k: WhisperxModel(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}", align_model)
             for k in ["base.en", "small.en", "medium.en"]
         }
 
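The per-size cache directory simply drops the ".en" suffix, which is what the k.split('.')[0] in the dict comprehension computes; a quick check:

# mapping of choice key to cache directory name (derived from the hunk)
for k in ["base.en", "small.en", "medium.en"]:
    print(k, "->", f"whisperx_{k.split('.')[0]}")
# base.en -> whisperx_base, small.en -> whisperx_small, medium.en -> whisperx_medium

Note that a single WhisperxAlignModel instance is shared across all three entries, so only the ASR weights differ per size.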
@@ -174,24 +144,19 @@ class Predictor(BasePredictor):
             choices=["giga830M.pth", "giga330M.pth", "giga330M_TTSEnhanced.pth"],
             default="giga330M_TTSEnhanced.pth",
         ),
-        orig_audio: Path = Input(
-            description="Original audio file. WhisperX small.en model will be used for transcription"
-        ),
+        orig_audio: Path = Input(description="Original audio file"),
         orig_transcript: str = Input(
-            description="Optionally provide the transcript of the input audio. Leave it blank to use the whisper model below to generate the transcript. Inaccurate transcription may lead to error TTS or speech editing",
+            description="Optionally provide the transcript of the input audio. Leave it blank to use the WhisperX model below to generate the transcript. Inaccurate transcription may lead to errors in TTS or speech editing",
             default="",
         ),
-        whisper_model: str = Input(
-            description="If orig_transcript is not provided above, choose a Whisper or WhisperX model. WhisperX model contains extra alignment steps. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ",
+        whisperx_model: str = Input(
+            description="If orig_transcript is not provided above, choose a WhisperX model for generating the transcript. Inaccurate transcription may lead to errors in TTS or speech editing. You can modify the generated transcript and provide it directly to orig_transcript above",
             choices=[
-                "whisper-base.en",
-                "whisper-small.en",
-                "whisper-medium.en",
-                "whisperx-base.en",
-                "whisperx-small.en",
-                "whisperx-medium.en",
+                "base.en",
+                "small.en",
+                "medium.en",
             ],
-            default="whisper-base.en",
+            default="base.en",
         ),
         target_transcript: str = Input(
             description="Transcript of the target audio file",
@@ -202,6 +167,7 @@ class Predictor(BasePredictor):
         ),
         kvcache: int = Input(
             description="Set to 0 to use less VRAM, but with slower inference",
+            choices=[0, 1],
             default=1,
         ),
         left_margin: float = Input(
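For context on the new choices=[0, 1] flag: a key-value cache stores each past token's attention keys and values, so a new token attends with one matmul instead of recomputing the whole prefix, trading VRAM for speed. A toy illustration (shapes invented, unrelated to VoiceCraft's internals):

import torch

d, T = 64, 500                  # head dim, tokens generated so far
k_cache = torch.randn(T, d)     # kvcache=1: keep all past keys (memory grows with T)
q_new = torch.randn(1, d)
scores = q_new @ k_cache.T      # one (1, T) matmul per new token (fast)
# kvcache=0: rebuild the keys for all T tokens at every step (slow, less memory)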
@@ -217,17 +183,15 @@ class Predictor(BasePredictor):
             default=1,
         ),
         top_p: float = Input(
-            description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens",
-            ge=0.0,
-            le=1.0,
-            default=0.8,
+            description="Default value for TTS is 0.9, and 0.8 for speech editing",
+            default=0.9,
         ),
         stop_repetition: int = Input(
-            default=-1,
-            description=" -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4",
+            default=3,
+            description="Default value for TTS is 3, and -1 for speech editing. -1 means do not adjust the probability of silence tokens. If there are long silences or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4",
         ),
         sample_batch_size: int = Input(
-            description="The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one",
+            description="Default value for TTS is 4, and 1 for speech editing. The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one",
             default=4,
         ),
         seed: int = Input(
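The rewritten descriptions carry two sets of recommended values (TTS vs. speech editing) that a single default= cannot express. A sketch of how a caller might select them; the dict names are invented, the numbers are quoted from the descriptions:

TTS_DEFAULTS = {"top_p": 0.9, "stop_repetition": 3, "sample_batch_size": 4}
EDIT_DEFAULTS = {"top_p": 0.8, "stop_repetition": -1, "sample_batch_size": 1}

def sampling_params(task: str) -> dict:
    # speech-editing tasks get the editing values, everything else the TTS values
    return EDIT_DEFAULTS if "edit" in task else TTS_DEFAULTS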
@@ -242,16 +206,9 @@ class Predictor(BasePredictor):
         seed_everything(seed)
 
-        whisper_model, whisper_model_size = whisper_model.split("-")
-        if whisper_model == "whisper":
-            segments = self.transcribe_models_whisper[whisper_model_size].transcribe(
-                str(orig_audio)
-            )
-        else:
-            segments = self.transcribe_models_whisperx[whisper_model_size].transcribe(
-                str(orig_audio)
-            )
+        segments = self.transcribe_models[whisperx_model].transcribe(
+            str(orig_audio)
+        )
 
         state = get_transcribe_state(segments)
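The old routing encoded the backend in the choice string and needed a split plus a branch; the new scheme uses the choice as the dict key directly. Side by side, restated from the hunk:

choice = "whisperx-base.en"
backend, size = choice.split("-")   # old: ("whisperx", "base.en"), then pick one of two dicts
# segments = (models_whisper if backend == "whisper" else models_whisperx)[size].transcribe(...)

# new: no split, no branch
# segments = self.transcribe_models["base.en"].transcribe(str(orig_audio))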
@@ -290,8 +247,8 @@ class Predictor(BasePredictor):
             prompt_end_frame = int(cut_off_sec * info.sample_rate)
             idx = find_closest_cut_off_word(state["word_bounds"], cut_off_sec)
-            orig_transcript_until_cutoff_time = "".join(
-                [word_bound["word"] for word_bound in state["word_bounds"][:idx]]
+            orig_transcript_until_cutoff_time = " ".join(
+                [word_bound["word"] for word_bound in state["word_bounds"][: idx + 1]]
             )
         else:
             edit_type = task.split("-")[-1]
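This fix changes two things at once: the separator (the old no-space join glued words together) and the slice bound (the old [:idx] dropped the word closest to the cutoff). A worked example with made-up word bounds:

# illustrative word bounds; real ones come from get_transcribe_state
word_bounds = [{"word": "hello"}, {"word": "there"}, {"word": "world"}]
idx = 1  # assumed return of find_closest_cut_off_word

old = "".join(w["word"] for w in word_bounds[:idx])        # 'hello'        (loses "there")
new = " ".join(w["word"] for w in word_bounds[: idx + 1])  # 'hello there'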
@@ -346,7 +303,7 @@ class Predictor(BasePredictor):
             self.audio_tokenizer,
             audio_fn,
             orig_transcript_until_cutoff_time.strip()
-            + ""
+            + " "
             + target_transcript.strip(),
             self.device,
             decode_config,
@@ -427,6 +384,6 @@ def get_mask_interval_from_word_bounds(word_bounds, word_span_ind, editType):
                 end = float(item["start"])
             else:
                 end = float(item["end"])
-            assert start != None
+            assert start is not None
             break
     return (start, end)
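The replaced assert is the idiomatic fix: is not None tests identity, while != None dispatches to __ne__, which a class can override into nonsense. A self-contained illustration with a contrived class:

class AlwaysEqual:
    def __ne__(self, other):
        return False  # claims equality with everything, including None

x = AlwaysEqual()
print(x != None)      # False: __ne__ lies
print(x is not None)  # True: identity cannot be overridden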