mirror of
https://github.com/jasonppy/VoiceCraft.git
synced 2025-01-23 06:50:18 +01:00
150 lines
5.0 KiB
Python
150 lines
5.0 KiB
Python
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
|
||
# Copyright 2023 (authors: Feiteng Li)
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
import re
|
||
from dataclasses import asdict, dataclass
|
||
from typing import Any, Dict, List, Optional, Pattern, Union
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torchaudio
|
||
# from lhotse.features import FeatureExtractor
|
||
# from lhotse.utils import Seconds, compute_num_frames
|
||
from phonemizer.backend import EspeakBackend
|
||
from phonemizer.backend.espeak.language_switch import LanguageSwitch
|
||
from phonemizer.backend.espeak.words_mismatch import WordMismatch
|
||
from phonemizer.punctuation import Punctuation
|
||
from phonemizer.separator import Separator
|
||
|
||
|
||
|
||
class TextTokenizer:
    """Turn raw text into flat phoneme-token lists via phonemizer's espeak backend."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        """Build the espeak phonemizer and remember the separator set.

        Args:
            language: Language code handed to ``EspeakBackend``.
            backend: Accepted but not used here; an ``EspeakBackend`` is
                always constructed.
            separator: Word/syllable/phone separators, used both when
                phonemizing and when splitting the result in ``to_list``.
            preserve_punctuation: Forwarded to ``EspeakBackend``.
            punctuation_marks: Forwarded to ``EspeakBackend``.
            with_stress: Forwarded to ``EspeakBackend``.
            tie: Forwarded to ``EspeakBackend``.
            language_switch: Forwarded to ``EspeakBackend``.
            words_mismatch: Forwarded to ``EspeakBackend``.
        """
        espeak_options = dict(
            punctuation_marks=punctuation_marks,
            preserve_punctuation=preserve_punctuation,
            with_stress=with_stress,
            tie=tie,
            language_switch=language_switch,
            words_mismatch=words_mismatch,
        )
        self.backend = EspeakBackend(language, **espeak_options)
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        """Split one phonemized string into a flat list of tokens.

        Words are emitted in order with the word separator kept as its own
        token between consecutive words; phone separators are dropped.

        Args:
            phonemized: A single string as produced by the backend.

        Returns:
            The token list, without a trailing word separator.

        Raises:
            AssertionError: If the tokens do not account for every character
                of ``phonemized`` other than phone separators (for example
                when the input contains whitespace, which the pattern skips).
        """
        word_sep = self.separator.word
        phone_sep = self.separator.phone

        tokens = []
        for chunk in phonemized.split(word_sep):
            # Each match is either a run of word characters or one single
            # non-word, non-space character (punctuation, phone separator).
            pieces = re.findall(r"\w+|[^\w\s]", chunk, re.UNICODE)
            tokens += [piece for piece in pieces if piece != phone_sep]
            tokens.append(word_sep)

        # The loop appends a separator after every word, including the last.
        tokens = tokens[:-1]

        expected_chars = len(phonemized) - phonemized.count(phone_sep)
        assert len("".join(tokens)) == expected_chars
        return tokens

    def __call__(self, text, strip=True) -> List[List[str]]:
        """Phonemize a string or a list of strings.

        Args:
            text: One string or a list of strings; a lone string is treated
                as a one-element batch.
            strip: Forwarded to the backend's ``phonemize``.

        Returns:
            One token list per input string, each built by ``to_list``.
        """
        batch = [text] if isinstance(text, str) else text
        outputs = self.backend.phonemize(
            batch, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(item) for item in outputs]
|
||
|
||
|
||
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    """Phonemize a single string and return its token list.

    Args:
        tokenizer: Callable tokenizer that maps a list of strings to a list
            of token lists.
        text: Input text; leading and trailing whitespace is removed first.

    Returns:
        The token list for ``text`` (the first entry of the batch result).
    """
    cleaned = text.strip()
    batch_result = tokenizer([cleaned])
    return batch_result[0]
|
||
|
||
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Convert a waveform to the requested channel count and sample rate.

    Args:
        wav: Waveform whose first dimension has size 1 (mono) or 2 (stereo).
        sr: Sample rate of ``wav``.
        target_sr: Desired sample rate.
        target_channels: Desired number of channels.

    Returns:
        The converted waveform. Channel expansion uses ``Tensor.expand``, so
        the result may be a view that shares memory with the input.

    Raises:
        ValueError: If ``wav`` is neither mono nor stereo, or if stereo audio
            is asked to become more than two channels (there is no defined
            way to do that, and passing it through unchanged would silently
            return the wrong channel count).
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if wav.shape[0] not in (1, 2):
        raise ValueError("Audio must be mono or stereo.")

    if target_channels == 1:
        # Downmix by averaging over the channel dimension.
        wav = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        *shape, _, length = wav.shape
        wav = wav.expand(*shape, target_channels, length)
    elif wav.shape[0] == 1:
        # Mono can be broadcast to any channel count.
        wav = wav.expand(target_channels, -1)
    else:
        raise ValueError(
            f"Cannot convert {wav.shape[0]}-channel audio "
            f"to {target_channels} channels."
        )

    # Resample returns its input unchanged when both rates match, so skipping
    # it only avoids building the transform.
    if sr != target_sr:
        wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    return wav
|
||
|
||
class AudioTokenizer:
    """Wrap an audiocraft compression model for encoding and decoding audio."""

    def __init__(
        self,
        device: Any = None,
        signature=None,
    ) -> None:
        """Load the compression model and move it to the chosen device.

        Args:
            device: Device to run the codec on. Any falsy value (including
                ``None``) selects ``cuda:0`` when CUDA is available and the
                CPU otherwise.
            signature: Checkpoint identifier passed unchanged to
                ``CompressionSolver.model_from_checkpoint``.
        """
        # Imported here so the module can be imported without audiocraft.
        from audiocraft.solvers import CompressionSolver

        model = CompressionSolver.model_from_checkpoint(signature)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

        if not device:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self._device = device
        self.codec = model.to(device)

    @property
    def device(self):
        """Device the codec was moved to at construction time."""
        return self._device

    def encode(self, wav: torch.Tensor) -> List[tuple]:
        """Encode a waveform with the codec.

        Args:
            wav: Waveform tensor; it is moved to ``self.device`` first.

        Returns:
            A one-element list ``[(codes, None)]`` where ``codes`` is the
            first item of the codec's ``encode`` result. The second slot is
            always ``None``.
        """
        codes = self.codec.encode(wav.to(self.device))
        return [(codes[0], None)]

    def decode(self, frames: List[tuple]) -> torch.Tensor:
        """Decode frames in the layout produced by ``encode``.

        Args:
            frames: Sequence whose first entry is a tuple with the codes
                tensor in position 0; any other entries are ignored.

        Returns:
            Whatever the codec's ``decode`` returns for those codes.
        """
        codes = frames[0][0]  # original note: [1,4,T] — TODO confirm shape
        return self.codec.decode(codes)
|
||
|
||
|
||
|
||
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset=-1, num_frames=-1):
    """Load an audio file (or a slice of it) and encode it with the tokenizer.

    Args:
        tokenizer: Provides ``sample_rate``, ``channels`` and ``encode``.
        audio_path: Path passed to ``torchaudio.load``.
        offset: First frame to read; ``-1`` means start of file.
        num_frames: Number of frames to read; ``-1`` means read to the end.

    Returns:
        Whatever ``tokenizer.encode`` returns for the loaded waveform.
    """
    # Each bound is honoured on its own. Previously both had to be set, and
    # supplying only one silently loaded the whole file. torchaudio already
    # treats num_frames=-1 as "to the end"; only offset's -1 sentinel needs
    # mapping to torchaudio's default of 0.
    frame_offset = 0 if offset == -1 else offset
    wav, sr = torchaudio.load(
        audio_path, frame_offset=frame_offset, num_frames=num_frames
    )

    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    wav = wav.unsqueeze(0)  # add a leading dimension before encoding

    # Encoding only; no gradients are needed.
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames
|