mirror of
https://github.com/jasonppy/VoiceCraft.git
synced 2025-06-05 21:49:11 +02:00
init
This commit is contained in:
149
data/tokenizer.py
Normal file
149
data/tokenizer.py
Normal file
@ -0,0 +1,149 @@
|
||||
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
|
||||
# Copyright 2023 (authors: Feiteng Li)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, List, Optional, Pattern, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
# from lhotse.features import FeatureExtractor
|
||||
# from lhotse.utils import Seconds, compute_num_frames
|
||||
from phonemizer.backend import EspeakBackend
|
||||
from phonemizer.backend.espeak.language_switch import LanguageSwitch
|
||||
from phonemizer.backend.espeak.words_mismatch import WordMismatch
|
||||
from phonemizer.punctuation import Punctuation
|
||||
from phonemizer.separator import Separator
|
||||
|
||||
|
||||
|
||||
class TextTokenizer:
|
||||
"""Phonemize Text."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
language="en-us",
|
||||
backend="espeak",
|
||||
separator=Separator(word="_", syllable="-", phone="|"),
|
||||
preserve_punctuation=True,
|
||||
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
|
||||
with_stress: bool = False,
|
||||
tie: Union[bool, str] = False,
|
||||
language_switch: LanguageSwitch = "keep-flags",
|
||||
words_mismatch: WordMismatch = "ignore",
|
||||
) -> None:
|
||||
phonemizer = EspeakBackend(
|
||||
language,
|
||||
punctuation_marks=punctuation_marks,
|
||||
preserve_punctuation=preserve_punctuation,
|
||||
with_stress=with_stress,
|
||||
tie=tie,
|
||||
language_switch=language_switch,
|
||||
words_mismatch=words_mismatch,
|
||||
)
|
||||
|
||||
self.backend = phonemizer
|
||||
self.separator = separator
|
||||
|
||||
def to_list(self, phonemized: str) -> List[str]:
|
||||
fields = []
|
||||
for word in phonemized.split(self.separator.word):
|
||||
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
|
||||
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
|
||||
fields.extend(
|
||||
[p for p in pp if p != self.separator.phone]
|
||||
+ [self.separator.word]
|
||||
)
|
||||
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
|
||||
self.separator.phone
|
||||
)
|
||||
return fields[:-1]
|
||||
|
||||
def __call__(self, text, strip=True) -> List[List[str]]:
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
phonemized = self.backend.phonemize(
|
||||
text, separator=self.separator, strip=strip, njobs=1
|
||||
)
|
||||
return [self.to_list(p) for p in phonemized]
|
||||
|
||||
|
||||
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
|
||||
phonemes = tokenizer([text.strip()])
|
||||
return phonemes[0] # k2symbols
|
||||
|
||||
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
|
||||
assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
|
||||
if target_channels == 1:
|
||||
wav = wav.mean(0, keepdim=True)
|
||||
elif target_channels == 2:
|
||||
*shape, _, length = wav.shape
|
||||
wav = wav.expand(*shape, target_channels, length)
|
||||
elif wav.shape[0] == 1:
|
||||
wav = wav.expand(target_channels, -1)
|
||||
wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
|
||||
return wav
|
||||
|
||||
class AudioTokenizer:
|
||||
"""EnCodec audio."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: Any = None,
|
||||
signature = None
|
||||
) -> None:
|
||||
from audiocraft.solvers import CompressionSolver
|
||||
model = CompressionSolver.model_from_checkpoint(signature)
|
||||
self.sample_rate = model.sample_rate
|
||||
self.channels = model.channels
|
||||
|
||||
if not device:
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
self._device = device
|
||||
|
||||
self.codec = model.to(device)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device
|
||||
|
||||
def encode(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
codes = self.codec.encode(wav.to(self.device))
|
||||
return [(codes[0], None)]
|
||||
|
||||
def decode(self, frames: torch.Tensor) -> torch.Tensor:
|
||||
frames = frames[0][0] # [1,4,T]
|
||||
return self.codec.decode(frames)
|
||||
|
||||
|
||||
|
||||
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1):
|
||||
# Load and pre-process the audio waveform
|
||||
if offset != -1 and num_frames!=-1:
|
||||
wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
|
||||
else:
|
||||
wav, sr = torchaudio.load(audio_path)
|
||||
wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
|
||||
wav = wav.unsqueeze(0)
|
||||
|
||||
# Extract discrete codes from EnCodec
|
||||
with torch.no_grad():
|
||||
encoded_frames = tokenizer.encode(wav)
|
||||
return encoded_frames
|
Reference in New Issue
Block a user