# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union

import numpy as np
import torch
import torchaudio

# from lhotse.features import FeatureExtractor
# from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator


class TextTokenizer:
    """Phonemize text with an espeak backend.

    Wraps :class:`phonemizer.backend.EspeakBackend` and converts its
    separator-delimited output into flat per-utterance token lists.
    """

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        # NOTE(review): the `backend` argument is accepted but unused —
        # espeak is always constructed; kept for interface compatibility.
        self.backend = EspeakBackend(
            language,
            punctuation_marks=punctuation_marks,
            preserve_punctuation=preserve_punctuation,
            with_stress=with_stress,
            tie=tie,
            language_switch=language_switch,
            words_mismatch=words_mismatch,
        )
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        """Flatten one phonemized string into a token list.

        Phone separators are dropped; a word separator token is kept
        between words. Example input: "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
        """
        phone_sep = self.separator.phone
        word_sep = self.separator.word
        tokens: List[str] = []
        for chunk in phonemized.split(word_sep):
            # keep runs of word characters and single punctuation marks
            pieces = re.findall(r"\w+|[^\w\s]", chunk, re.UNICODE)
            tokens += [p for p in pieces if p != phone_sep]
            tokens.append(word_sep)
        # sanity check: only phone separators were removed from the input
        assert len("".join(tokens[:-1])) == len(phonemized) - phonemized.count(
            phone_sep
        )
        return tokens[:-1]

    def __call__(self, text, strip=True) -> List[List[str]]:
        """Phonemize `text` (a string or list of strings) into token lists."""
        inputs = [text] if isinstance(text, str) else text
        phonemized = self.backend.phonemize(
            inputs, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(entry) for entry in phonemized]
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    """Phonemize a single string and return its token list."""
    cleaned = text.strip()
    batch = tokenizer([cleaned])
    return batch[0]  # k2symbols
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Convert a waveform to the target sample rate and channel count.

    Args:
        wav: Audio tensor of shape (channels, num_samples); mono or stereo.
        sr: Sample rate of ``wav``.
        target_sr: Desired output sample rate.
        target_channels: Desired output channel count.

    Returns:
        Tensor of shape (target_channels, resampled_length).

    Raises:
        ValueError: if ``wav`` is not mono/stereo, or the requested
            conversion is unsupported (stereo to more than 2 channels).
    """
    # Was a bare `assert`, which is stripped under `python -O`; raise instead.
    if wav.shape[0] not in (1, 2):
        raise ValueError("Audio must be mono or stereo.")
    if target_channels == 1:
        # downmix by averaging all input channels
        wav = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        # broadcast the channel dimension (no-op for stereo, copy for mono)
        *shape, _, length = wav.shape
        wav = wav.expand(*shape, target_channels, length)
    elif wav.shape[0] == 1:
        # mono input can be broadcast to any channel count
        wav = wav.expand(target_channels, -1)
    else:
        # BUG FIX: this case previously fell through all branches and
        # silently returned audio with the wrong channel count.
        raise ValueError(
            f"Cannot convert {wav.shape[0]} channels to {target_channels}."
        )
    if sr != target_sr:
        # Resample only when rates differ; torchaudio's Resample returns the
        # waveform unchanged when orig_freq == new_freq, so this is a free skip.
        wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    return wav
class AudioTokenizer:
    """EnCodec audio tokenizer backed by an audiocraft compression model."""

    def __init__(
        self,
        device: Any = None,
        signature=None,
    ) -> None:
        # deferred import: audiocraft is only required when a tokenizer is built
        from audiocraft.solvers import CompressionSolver

        model = CompressionSolver.model_from_checkpoint(signature)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

        if not device:
            # default to the first GPU when available, otherwise CPU
            if torch.cuda.is_available():
                device = torch.device("cuda:0")
            else:
                device = torch.device("cpu")

        self._device = device
        self.codec = model.to(device)

    @property
    def device(self):
        """Device the codec model lives on."""
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        """Encode a waveform; returns a one-element list [(codes, None)]."""
        encoded = self.codec.encode(wav.to(self.device))
        return [(encoded[0], None)]

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        """Decode [(codes, scale)] frames back into a waveform."""
        codes = frames[0][0]  # [1, 4, T]
        return self.codec.decode(codes)
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset=-1, num_frames=-1):
    """Load an audio file and encode it to EnCodec codes.

    Args:
        tokenizer: AudioTokenizer providing the target sample rate/channels
            and the encoder.
        audio_path: Path of the audio file to load.
        offset: Starting frame within the file; -1 (default) starts at 0.
        num_frames: Number of frames to read; -1 (default) reads to the end.

    Returns:
        Encoded frames as produced by ``tokenizer.encode`` ([(codes, None)]).
    """
    # BUG FIX: previously offset/num_frames were honored only when BOTH were
    # set; passing just one was silently ignored. torchaudio.load's own
    # defaults are frame_offset=0 and num_frames=-1, so mapping each flag
    # independently is backward-compatible and fixes the partial case.
    wav, sr = torchaudio.load(
        audio_path,
        frame_offset=max(offset, 0),
        num_frames=num_frames,
    )
    # match the codec's expected sample rate and channel layout
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    wav = wav.unsqueeze(0)  # add batch dimension -> [1, C, T]

    # Extract discrete codes from EnCodec
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames