update

parent 2a2ee984b6
commit 87f4fa5d21
@@ -1 +0,0 @@
-Subproject commit 69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85
cog.yaml (3 changed lines)
@@ -17,8 +17,7 @@ build:
     - whisperx==3.1.1
-    - openai-whisper>=20231117
   run:
     # - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft
-    - pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft # use "git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft" instead if hits audiocraft import error
+    - git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft
     - pip install "pydantic<2.0.0"
     - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
     - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth"
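The pinned editable VCS install is swapped for the plain clone-and-install that the old inline comment already recommended as the fallback for audiocraft import errors. A minimal smoke test (hypothetical, not part of the repo) for checking that the build step left audiocraft importable:

import importlib

def check_audiocraft():
    # On failure, cog.yaml's comment recommends the clone-based install path.
    try:
        mod = importlib.import_module("audiocraft")
        print("audiocraft OK, version:", getattr(mod, "__version__", "unknown"))
    except ImportError as err:
        print("audiocraft import failed:", err)

if __name__ == "__main__":
    check_audiocraft()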
predict.py (87 changed lines)
@@ -10,8 +10,6 @@ import subprocess
 import torch
 import numpy as np
 import torchaudio
-from whisper.model import Whisper, ModelDimensions
-from whisper.tokenizer import get_tokenizer
 from cog import BasePredictor, Input, Path, BaseModel
 
 os.environ["USER"] = getpass.getuser()
@@ -83,30 +81,6 @@ class WhisperxModel:
         return self.align_model.align(segments, audio_path)
 
 
-class WhisperModel:
-    def __init__(self, model_cache, model_name="base.en", device="cuda"):
-
-        # the model weights are cached from https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/__init__.py#L17
-        with open(f"{model_cache}/{model_name}.pt", "rb") as fp:
-            checkpoint = torch.load(fp, map_location="cpu")
-        dims = ModelDimensions(**checkpoint["dims"])
-        self.model = Whisper(dims)
-        self.model.load_state_dict(checkpoint["model_state_dict"])
-        self.model.to(device)
-
-        tokenizer = get_tokenizer(multilingual=False)
-        self.supress_tokens = [-1] + [
-            i
-            for i in range(tokenizer.eot)
-            if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
-        ]
-
-    def transcribe(self, audio_path):
-        return self.model.transcribe(
-            audio_path, suppress_tokens=self.supress_tokens, word_timestamps=True
-        )["segments"]
-
-
 def download_weights(url, dest):
     start = time.time()
     print("downloading url: ", url)
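For reference, the deleted WhisperModel loaded raw openai-whisper checkpoint weights from the local cache and suppressed every vocabulary token that decodes to pure digits, presumably so transcripts spell numbers out as words that downstream alignment can handle. Its token scan, as a standalone sketch (requires the openai-whisper package this commit also drops from cog.yaml):

from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=False)
digit_tokens = [
    i
    for i in range(tokenizer.eot)
    if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
]
suppress_tokens = [-1] + digit_tokens  # -1 additionally suppresses non-speech tokens
print(len(digit_tokens), "digit-only tokens suppressed")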
@@ -146,13 +120,9 @@ class Predictor(BasePredictor):
 
         self.text_tokenizer = TextTokenizer(backend="espeak")
         self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)
-        self.transcribe_models_whisper = {
-            k: WhisperModel(MODEL_CACHE, k, self.device)
-            for k in ["base.en", "small.en", "medium.en"]
-        }
 
         align_model = WhisperxAlignModel()
-        self.transcribe_models_whisperx = {
+        self.transcribe_models = {
             k: WhisperxModel(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}", align_model)
             for k in ["base.en", "small.en", "medium.en"]
         }
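With only one model family left, the dict is renamed from transcribe_models_whisperx to transcribe_models. The comprehension's cache path derives a per-size directory from the model key; a one-line sketch of that mapping:

for k in ["base.en", "small.en", "medium.en"]:
    print(k, "->", f"whisperx_{k.split('.')[0]}")  # e.g. base.en -> whisperx_base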
@@ -174,24 +144,19 @@ class Predictor(BasePredictor):
             choices=["giga830M.pth", "giga330M.pth", "giga330M_TTSEnhanced.pth"],
             default="giga330M_TTSEnhanced.pth",
         ),
-        orig_audio: Path = Input(
-            description="Original audio file. WhisperX small.en model will be used for transcription"
-        ),
+        orig_audio: Path = Input(description="Original audio file"),
         orig_transcript: str = Input(
-            description="Optionally provide the transcript of the input audio. Leave it blank to use the whisper model below to generate the transcript. Inaccurate transcription may lead to error TTS or speech editing",
+            description="Optionally provide the transcript of the input audio. Leave it blank to use the WhisperX model below to generate the transcript. Inaccurate transcription may lead to error TTS or speech editing",
             default="",
         ),
-        whisper_model: str = Input(
-            description="If orig_transcript is not provided above, choose a Whisper or WhisperX model. WhisperX model contains extra alignment steps. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to ",
+        whisperx_model: str = Input(
+            description="If orig_transcript is not provided above, choose a WhisperX model for generating the transcript. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to orig_transcript above",
             choices=[
-                "whisper-base.en",
-                "whisper-small.en",
-                "whisper-medium.en",
-                "whisperx-base.en",
-                "whisperx-small.en",
-                "whisperx-medium.en",
+                "base.en",
+                "small.en",
+                "medium.en",
             ],
-            default="whisper-base.en",
+            default="base.en",
         ),
         target_transcript: str = Input(
             description="Transcript of the target audio file",
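The simplified choices drop the whisper-/whisperx- prefixes, so each choice is now exactly a key of self.transcribe_models and predict() can index the dict with the raw input value (see the dispatch hunk below). A minimal sketch of that invariant, with placeholder model objects:

choices = ["base.en", "small.en", "medium.en"]
transcribe_models = {k: object() for k in choices}  # placeholders, not real models
for whisperx_model in choices:
    assert whisperx_model in transcribe_models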
@@ -202,6 +167,7 @@ class Predictor(BasePredictor):
         ),
         kvcache: int = Input(
             description="Set to 0 to use less VRAM, but with slower inference",
+            choices=[0, 1],
             default=1,
         ),
         left_margin: float = Input(
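kvcache=1 keeps each attention layer's key/value tensors between autoregressive steps, so a step only computes attention inputs for the newest token; kvcache=0 recomputes them and saves VRAM at the cost of speed. A toy sketch of the cache growth, with hypothetical shapes rather than VoiceCraft's internals:

import torch

cache_k = torch.empty(0, 8, 64)            # (seq_len, heads, head_dim)
for step in range(3):
    new_k = torch.randn(1, 8, 64)          # keys for the newest token only
    cache_k = torch.cat([cache_k, new_k])  # history is reused, not recomputed
print(cache_k.shape)                       # torch.Size([3, 8, 64])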
@@ -217,17 +183,15 @@ class Predictor(BasePredictor):
             default=1,
         ),
         top_p: float = Input(
-            description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens",
-            ge=0.0,
-            le=1.0,
-            default=0.8,
+            description="Default value for TTS is 0.9, and 0.8 for speech editing",
+            default=0.9,
         ),
         stop_repetition: int = Input(
-            default=-1,
-            description=" -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4",
+            default=3,
+            description="Default value for TTS is 3, and -1 for speech editing. -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4",
         ),
         sample_batch_size: int = Input(
-            description="The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one",
+            description="Default value for TTS is 4, and 1 for speech editing. The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one",
             default=4,
         ),
         seed: int = Input(
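The rewritten descriptions encode the task-dependent defaults (TTS vs. speech editing) instead of restating what top-p does. As background, top_p restricts sampling to the smallest set of tokens whose cumulative probability reaches p; a generic illustration of the mechanism, not VoiceCraft's internal sampler:

import numpy as np

def top_p_sample(probs, top_p=0.9, rng=np.random.default_rng(0)):
    order = np.argsort(probs)[::-1]          # most likely first
    keep = np.cumsum(probs[order]) <= top_p  # nucleus mask
    keep[0] = True                           # always keep the single best token
    kept = order[keep]
    return rng.choice(kept, p=probs[kept] / probs[kept].sum())

probs = np.array([0.5, 0.3, 0.1, 0.07, 0.03])
print(top_p_sample(probs))  # with top_p=0.9, only tokens 0-2 are candidates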
@@ -242,16 +206,9 @@ class Predictor(BasePredictor):
 
         seed_everything(seed)
 
-        whisper_model, whisper_model_size = whisper_model.split("-")
-
-        if whisper_model == "whisper":
-            segments = self.transcribe_models_whisper[whisper_model_size].transcribe(
-                str(orig_audio)
-            )
-        else:
-            segments = self.transcribe_models_whisperx[whisper_model_size].transcribe(
-                str(orig_audio)
-            )
+        segments = self.transcribe_models[whisperx_model].transcribe(
+            str(orig_audio)
+        )
 
         state = get_transcribe_state(segments)
 
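The removed branch parsed the old prefixed choice into a model family and a size, then picked one of two dicts; with a single dict keyed by size, the lookup is direct. What the old parsing did, in isolation:

whisper_model = "whisperx-base.en"       # old-style prefixed choice
family, size = whisper_model.split("-")  # ("whisperx", "base.en")
print(family, size)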
@@ -290,8 +247,8 @@ class Predictor(BasePredictor):
             prompt_end_frame = int(cut_off_sec * info.sample_rate)
 
             idx = find_closest_cut_off_word(state["word_bounds"], cut_off_sec)
-            orig_transcript_until_cutoff_time = "".join(
-                [word_bound["word"] for word_bound in state["word_bounds"][:idx]]
+            orig_transcript_until_cutoff_time = " ".join(
+                [word_bound["word"] for word_bound in state["word_bounds"][: idx + 1]]
             )
         else:
             edit_type = task.split("-")[-1]
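Two bugs in the prompt-transcript slice are fixed: words were joined without a separator, and the slice stopped one word short of the cut-off word whose index find_closest_cut_off_word returns. With hypothetical word bounds:

word_bounds = [{"word": "hello"}, {"word": "world"}, {"word": "today"}]
idx = 1  # hypothetical index of the closest cut-off word
before = "".join(w["word"] for w in word_bounds[:idx])       # "hello"
after = " ".join(w["word"] for w in word_bounds[: idx + 1])  # "hello world"
print(repr(before), "->", repr(after))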
@@ -346,7 +303,7 @@ class Predictor(BasePredictor):
             self.audio_tokenizer,
             audio_fn,
             orig_transcript_until_cutoff_time.strip()
-            + ""
+            + " "
             + target_transcript.strip(),
             self.device,
             decode_config,
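The same class of bug one call later: the prompt transcript and the target transcript were concatenated with an empty string, fusing the last prompt word into the first target word. A two-line illustration:

prompt, target = "up to the cut", "and the new text"
print(prompt + "" + target)   # "...the cutand the..." (words fused)
print(prompt + " " + target)  # "...the cut and the..."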
@@ -427,6 +384,6 @@ def get_mask_interval_from_word_bounds(word_bounds, word_span_ind, editType):
             end = float(item["start"])
         else:
             end = float(item["end"])
-        assert start != None
+        assert start is not None
         break
     return (start, end)
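assert start is not None is the idiomatic form (PEP 8): it tests identity rather than equality, so an object whose __eq__ claims to equal None cannot slip through. A contrived demonstration of the difference:

class Weird:
    def __eq__(self, other):
        return other is None  # claims equality with None

w = Weird()
print(w != None)      # False: the equality test is fooled by __eq__
print(w is not None)  # True: identity cannot be overridden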