This commit is contained in:
jason-on-salt-a40
2024-03-21 11:02:20 -07:00
commit 6760f29bd0
32 changed files with 9321 additions and 0 deletions

View File

@ -0,0 +1,538 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
from dataclasses import dataclass
from functools import lru_cache
import logging
import typing as tp
from abc import ABC, abstractmethod
import torch
LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
@dataclass
class Pattern:
"""Base implementation of a pattern over a sequence with multiple codebooks.
The codebook pattern consists in a layout, defining for each sequence step
the list of coordinates of each codebook timestep in the resulting interleaved sequence.
The first item of the pattern is always an empty list in order to properly insert a special token
to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
and ``timesteps`` the number of timesteps corresponding to the original sequence.
The pattern provides convenient methods to build and revert interleaved sequences from it:
``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
for the output sequence. The unfilled positions are replaced with a special token and the built sequence
is returned along with a mask indicating valid tokens.
``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
to fill and specify invalid positions if needed.
See the dedicated methods for more details.
"""
# Pattern layout, for each sequence step, we have a list of coordinates
# corresponding to the original codebook timestep and position.
# The first list is always an empty list in order to properly insert
# a special token to start with.
layout: PatternLayout
timesteps: int
n_q: int
def __post_init__(self):
assert len(self.layout) > 0
assert self.layout[0] == []
self._validate_layout()
self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
# logging.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
def _validate_layout(self):
"""Runs checks on the layout to ensure a valid pattern is defined.
A pattern is considered invalid if:
- Multiple timesteps for a same codebook are defined in the same sequence step
- The timesteps for a given codebook are not in ascending order as we advance in the sequence
(this would mean that we have future timesteps before past timesteps).
"""
q_timesteps = {q: 0 for q in range(self.n_q)}
for s, seq_coords in enumerate(self.layout):
if len(seq_coords) > 0:
qs = set()
for coord in seq_coords:
qs.add(coord.q)
last_q_timestep = q_timesteps[coord.q]
assert coord.t >= last_q_timestep, \
f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
q_timesteps[coord.q] = coord.t
# each sequence step contains at max 1 coordinate per codebook
assert len(qs) == len(seq_coords), \
f"Multiple entries for a same codebook are found at step {s}"
@property
def num_sequence_steps(self):
return len(self.layout) - 1
@property
def max_delay(self):
max_t_in_seq_coords = 0
for seq_coords in self.layout[1:]:
for coords in seq_coords:
max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
return max_t_in_seq_coords - self.timesteps
@property
def valid_layout(self):
valid_step = len(self.layout) - self.max_delay
return self.layout[:valid_step]
def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
"""Get codebook coordinates in the layout that corresponds to the specified timestep t
and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
and the actual codebook coordinates.
"""
assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
if q is not None:
assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
coords = []
for s, seq_codes in enumerate(self.layout):
for code in seq_codes:
if code.t == t and (q is None or code.q == q):
coords.append((s, code))
return coords
def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
steps_with_timesteps = self.get_steps_with_timestep(t, q)
return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
device: tp.Union[torch.device, str] = 'cpu'):
"""Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
Args:
timesteps (int): Maximum number of timesteps steps to consider.
keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
device (Union[torch.device, str]): Device for created tensors.
Returns:
indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
"""
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
# use the proper layout based on whether we limit ourselves to valid steps only or not,
# note that using the valid_layout will result in a truncated sequence up to the valid steps
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
# fill indexes with last sequence step value that will correspond to our special token
# the last value is n_q * timesteps as we have flattened z and append special token as the last token
# which will correspond to the index: n_q * timesteps
indexes[:] = n_q * timesteps
# iterate over the pattern and fill scattered indexes and mask
for s, sequence_coords in enumerate(ref_layout):
for coords in sequence_coords:
if coords.t < timesteps:
indexes[coords.q, s] = coords.t + coords.q * timesteps
mask[coords.q, s] = 1
indexes = torch.from_numpy(indexes).to(device)
mask = torch.from_numpy(mask).to(device)
return indexes, mask
def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
"""Build sequence corresponding to the pattern from the input tensor z.
The sequence is built using up to sequence_steps if specified, and non-pattern
coordinates are filled with the special token.
Args:
z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
Steps that are beyond valid steps will be replaced by the special_token in that case.
Returns:
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
"""
B, K, T = z.shape
indexes, mask = self._build_pattern_sequence_scatter_indexes(
T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
)
z = z.view(B, -1)
# we append the special token as the last index of our flattened z tensor
z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
values = z[:, indexes.view(-1)]
values = values.view(B, K, indexes.shape[-1])
return values, indexes, mask
def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
keep_only_valid_steps: bool = False,
is_model_output: bool = False,
device: tp.Union[torch.device, str] = 'cpu'):
"""Builds scatter indexes required to retrieve the original multi-codebook sequence
from interleaving pattern.
Args:
sequence_steps (int): Sequence steps.
n_q (int): Number of codebooks.
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
Steps that are beyond valid steps will be replaced by the special_token in that case.
is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
device (Union[torch.device, str]): Device for created tensors.
Returns:
torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
"""
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
# TODO(jade): Do we want to further truncate to only valid timesteps here as well?
timesteps = self.timesteps
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
assert sequence_steps <= len(ref_layout), \
f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
# ensure we take the appropriate indexes to keep the model output from the first special token as well
if is_model_output:
ref_layout = ref_layout[1:]
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
# fill indexes with last sequence step value that will correspond to our special token
indexes[:] = n_q * sequence_steps
for s, sequence_codes in enumerate(ref_layout):
if s < sequence_steps:
for code in sequence_codes:
if code.t < timesteps:
indexes[code.q, code.t] = s + code.q * sequence_steps
mask[code.q, code.t] = 1
indexes = torch.from_numpy(indexes).to(device)
mask = torch.from_numpy(mask).to(device)
return indexes, mask
def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
"""Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
are filled with the special token.
Args:
s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
Returns:
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
"""
B, K, S = s.shape
indexes, mask = self._build_reverted_sequence_scatter_indexes(
S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
)
s = s.view(B, -1)
# we append the special token as the last index of our flattened z tensor
s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
values = s[:, indexes.view(-1)]
values = values.view(B, K, indexes.shape[-1])
return values, indexes, mask
def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
"""Revert model logits obtained on a sequence built from the pattern
back to a tensor matching the original sequence.
This method is similar to ``revert_pattern_sequence`` with the following specificities:
1. It is designed to work with the extra cardinality dimension
2. We return the logits for the first sequence item that matches the special_token and
which matching target in the original sequence is the first item of the sequence,
while we skip the last logits as there is no matching target
"""
B, card, K, S = logits.shape
indexes, mask = self._build_reverted_sequence_scatter_indexes(
S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
)
logits = logits.reshape(B, card, -1)
# we append the special token as the last index of our flattened z tensor
logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
values = logits[:, :, indexes.view(-1)]
values = values.view(B, card, K, indexes.shape[-1])
return values, indexes, mask
class CodebooksPatternProvider(ABC):
"""Abstraction around providing pattern for interleaving codebooks.
The CodebooksPatternProvider abstraction allows to implement various strategies to
define interleaving pattern of sequences composed of multiple codebooks. For a given
number of codebooks `n_q`, the pattern provider can generate a specified pattern
corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
can be used to construct a new sequence from the original codes respecting the specified
pattern. The pattern is defined as a list of list of code coordinates, code coordinate
being a tuple with the original timestep and codebook to build the new sequence.
Note that all patterns must start with an empty list that is then used to insert a first
sequence step of special tokens in the newly generated sequence.
Args:
n_q (int): number of codebooks.
cached (bool): if True, patterns for a given length are cached. In general
that should be true for efficiency reason to avoid synchronization points.
"""
def __init__(self, n_q: int, cached: bool = True):
assert n_q > 0
self.n_q = n_q
self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
@abstractmethod
def get_pattern(self, timesteps: int) -> Pattern:
"""Builds pattern with specific interleaving between codebooks.
Args:
timesteps (int): Total numer of timesteps.
"""
raise NotImplementedError()
class DelayedPatternProvider(CodebooksPatternProvider):
"""Provider for delayed pattern across delayed codebooks.
Codebooks are delayed in the sequence and sequence steps will contain codebooks
from different timesteps.
Example:
Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
[[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]]
The resulting sequence obtained from the returned pattern is:
[[S, 1, 2, 3, 4],
[S, S, 1, 2, 3],
[S, S, S, 1, 2]]
(with S being a special token)
Args:
n_q (int): Number of codebooks.
delays (Optional[List[int]]): Delay for each of the codebooks.
If delays not defined, each codebook is delayed by 1 compared to the previous one.
flatten_first (int): Flatten the first N timesteps.
empty_initial (int): Prepend with N empty list of coordinates.
"""
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
flatten_first: int = 0, empty_initial: int = 0):
super().__init__(n_q)
if delays is None:
delays = list(range(n_q))
self.delays = delays
self.flatten_first = flatten_first
self.empty_initial = empty_initial
assert len(self.delays) == self.n_q
assert sorted(self.delays) == self.delays
def get_pattern(self, timesteps: int) -> Pattern:
out: PatternLayout = [[]]
max_delay = max(self.delays)
if self.empty_initial:
out += [[] for _ in range(self.empty_initial)]
if self.flatten_first:
for t in range(min(timesteps, self.flatten_first)):
for q in range(self.n_q):
out.append([LayoutCoord(t, q)])
for t in range(self.flatten_first, timesteps + max_delay):
v = []
for q, delay in enumerate(self.delays):
t_for_q = t - delay
if t_for_q >= self.flatten_first:
v.append(LayoutCoord(t_for_q, q))
out.append(v)
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
class ParallelPatternProvider(DelayedPatternProvider):
"""Provider for parallel pattern across codebooks.
This pattern provider is a special case of the delayed pattern with actually no delay,
hence delays=repeat(0, n_q).
Args:
n_q (int): Number of codebooks.
"""
def __init__(self, n_q: int):
super().__init__(n_q, [0] * n_q)
class UnrolledPatternProvider(CodebooksPatternProvider):
"""Provider for unrolling codebooks pattern.
This pattern provider enables to represent the codebook flattened completely or only to some extend
while also specifying a given delay between the flattened codebooks representation, allowing to
unroll the codebooks in the sequence.
Example:
1. Flattening of the codebooks.
By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
taking n_q = 3 and timesteps = 4:
[[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]]
will result into:
[[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
[[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]]
will result into:
[[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
and delays = [0, 3, 3]:
[[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]]
will result into:
[[S, S, S, 1, S, 2, S, 3, S, 4],
[S, S, S, 1, S, 2, S, 3, S, 4],
[1, 2, 3, S, 4, S, 5, S, 6, S]]
Args:
n_q (int): Number of codebooks.
flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
have n_q extra steps for each timestep.
delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
no delay is added and therefore will default to [0] * ``n_q``.
Note that two codebooks that will be flattened to the same inner step
should have the same delay, otherwise the pattern is considered as invalid.
"""
FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
delays: tp.Optional[tp.List[int]] = None):
super().__init__(n_q)
if flattening is None:
flattening = list(range(n_q))
if delays is None:
delays = [0] * n_q
assert len(flattening) == n_q
assert len(delays) == n_q
assert sorted(flattening) == flattening
assert sorted(delays) == delays
self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
self.max_delay = max(delays)
def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
"""Build a flattened codebooks representation as a dictionary of inner step
and the actual codebook indices corresponding to the flattened codebook. For convenience, we
also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
"""
flattened_codebooks: dict = {}
for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
if inner_step not in flattened_codebooks:
flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
else:
flat_codebook = flattened_codebooks[inner_step]
assert flat_codebook.delay == delay, (
"Delay and flattening between codebooks is inconsistent: ",
"two codebooks flattened to the same position should have the same delay."
)
flat_codebook.codebooks.append(q)
flattened_codebooks[inner_step] = flat_codebook
return flattened_codebooks
@property
def _num_inner_steps(self):
"""Number of inner steps to unroll between timesteps in order to flatten the codebooks.
"""
return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
def num_virtual_steps(self, timesteps: int) -> int:
return timesteps * self._num_inner_steps + 1
def get_pattern(self, timesteps: int) -> Pattern:
"""Builds pattern for delay across codebooks.
Args:
timesteps (int): Total numer of timesteps.
"""
# the PatternLayout is built as a tuple of sequence position and list of coordinates
# so that it can be reordered properly given the required delay between codebooks of given timesteps
indexed_out: list = [(-1, [])]
max_timesteps = timesteps + self.max_delay
for t in range(max_timesteps):
# for each timestep, we unroll the flattened codebooks,
# emitting the sequence step with the corresponding delay
for step in range(self._num_inner_steps):
if step in self._flattened_codebooks:
# we have codebooks at this virtual step to emit
step_codebooks = self._flattened_codebooks[step]
t_for_q = t + step_codebooks.delay
coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
if t_for_q < max_timesteps and t < max_timesteps:
indexed_out.append((t_for_q, coords))
else:
# there is no codebook in this virtual step so we emit an empty list
indexed_out.append((t, []))
out = [coords for _, coords in sorted(indexed_out)]
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
class VALLEPattern(CodebooksPatternProvider):
"""Almost VALL-E style pattern. We futher allow some delays for the
codebooks other than the first one.
Args:
n_q (int): Number of codebooks.
delays (Optional[List[int]]): Delay for each of the codebooks.
If delays not defined, each codebook is delayed by 1 compared to the previous one.
"""
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
super().__init__(n_q)
if delays is None:
delays = [0] * (n_q - 1)
self.delays = delays
assert len(self.delays) == self.n_q - 1
assert sorted(self.delays) == self.delays
def get_pattern(self, timesteps: int) -> Pattern:
out: PatternLayout = [[]]
for t in range(timesteps):
out.append([LayoutCoord(t, 0)])
max_delay = max(self.delays)
for t in range(timesteps + max_delay):
v = []
for q, delay in enumerate(self.delays):
t_for_q = t - delay
if t_for_q >= 0:
v.append(LayoutCoord(t_for_q, q + 1))
out.append(v)
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
class MusicLMPattern(CodebooksPatternProvider):
"""Almost MusicLM style pattern. This is equivalent to full flattening
but in a different order.
Args:
n_q (int): Number of codebooks.
group_by (int): Number of codebooks to group together.
"""
def __init__(self, n_q: int, group_by: int = 2):
super().__init__(n_q)
self.group_by = group_by
def get_pattern(self, timesteps: int) -> Pattern:
out: PatternLayout = [[]]
for offset in range(0, self.n_q, self.group_by):
for t in range(timesteps):
for q in range(offset, offset + self.group_by):
out.append([LayoutCoord(t, q)])
return Pattern(out, n_q=self.n_q, timesteps=timesteps)

View File

View File

@ -0,0 +1,653 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py, modified by Puyuan Peng, 2024
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
import logging
from typing import Callable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from torch.types import _dtype as DType
else:
# The JIT doesn't understand Union, nor torch.dtype here
DType = int
def _canonical_mask(
mask: Optional[Tensor],
mask_name: str,
other_type: Optional[DType],
other_name: str,
target_type: DType,
check_other: bool = True,
) -> Optional[Tensor]:
if mask is not None:
_mask_dtype = mask.dtype
_mask_is_float = torch.is_floating_point(mask)
if _mask_dtype != torch.bool and not _mask_is_float:
raise AssertionError(
f"only bool and floating types of {mask_name} are supported")
if check_other and other_type is not None:
if _mask_dtype != other_type:
warnings.warn(
f"Support for mismatched {mask_name} and {other_name} "
"is deprecated. Use same type for both instead."
)
if not _mask_is_float:
mask = (
torch.zeros_like(mask, dtype=target_type)
.masked_fill_(mask, float("-inf"))
)
return mask
def _in_projection_packed(
q: Tensor,
k: Tensor,
v: Tensor,
w: Tensor,
b: Optional[Tensor] = None,
) -> List[Tensor]:
r"""
Performs the in-projection step of the attention operation, using packed weights.
Output is a triple containing projection tensors for query, key and value.
Args:
q, k, v: query, key and value tensors to be projected. For self-attention,
these are typically the same tensor; for encoder-decoder attention,
k and v are typically the same tensor. (We take advantage of these
identities for performance if they are present.) Regardless, q, k and v
must share a common embedding dimension; otherwise their shapes may vary.
w: projection weights for q, k and v, packed into a single tensor. Weights
are packed along dimension 0, in q, k, v order.
b: optional projection biases for q, k and v, packed into a single tensor
in q, k, v order.
Shape:
Inputs:
- q: :math:`(..., E)` where E is the embedding dimension
- k: :math:`(..., E)` where E is the embedding dimension
- v: :math:`(..., E)` where E is the embedding dimension
- w: :math:`(E * 3, E)` where E is the embedding dimension
- b: :math:`E * 3` where E is the embedding dimension
Output:
- in output list :math:`[q', k', v']`, each output tensor will have the
same shape as the corresponding input tensor.
"""
E = q.size(-1)
if k is v:
if q is k:
# self-attention
proj = F.linear(q, w, b)
# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
return proj[0], proj[1], proj[2]
else:
# encoder-decoder attention
w_q, w_kv = w.split([E, E * 2])
if b is None:
b_q = b_kv = None
else:
b_q, b_kv = b.split([E, E * 2])
q_proj = F.linear(q, w_q, b_q)
kv_proj = F.linear(k, w_kv, b_kv)
# reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
return (q_proj, kv_proj[0], kv_proj[1])
else:
w_q, w_k, w_v = w.chunk(3)
if b is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
if input is None:
return None
elif isinstance(input, torch.Tensor):
return input.dtype
raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
from different representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
Multi-Head Attention is defined as:
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
``forward()`` will use a special optimized implementation if all of the following
conditions are met:
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
restriction will be loosened in the future.)
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
- training is disabled (using ``.eval()``)
- dropout is 0
- ``add_bias_kv`` is ``False``
- ``add_zero_attn`` is ``False``
- ``batch_first`` is ``True`` and the input is batched
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
nor ``attn_mask`` is passed
If the optimized implementation is in use, a
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
``query``/``key``/``value`` to represent padding more efficiently than using a
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
will be returned, and an additional speedup proportional to the fraction of the input
that is padding can be expected.
Args:
embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
Default: ``False``.
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
Examples::
>>> # xdoctest: +SKIP
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = (
self.kdim == embed_dim and self.vdim == embed_dim
)
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs)
)
self.bias_v = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs)
)
else:
self.bias_k = self.bias_v = None
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs)
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs)
)
self.register_parameter("in_proj_weight", None)
else:
# go down this route with voicecraft
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias: # True by default
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs)
)
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
self._reset_parameters()
else:
if not self._qkv_same_embed_dim:
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
)
self.in_proj_weight = self.in_proj_linear.weight
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = self.in_proj_linear.bias
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
self.add_zero_attn = add_zero_attn
def _reset_parameters(self):
if self._qkv_same_embed_dim:
xavier_uniform_(self.in_proj_weight)
else:
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
if self.in_proj_bias is not None:
constant_(self.in_proj_bias, 0.0)
constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
def __setstate__(self, state):
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
if "_qkv_same_embed_dim" not in state:
state["_qkv_same_embed_dim"] = True
super(MultiheadAttention, self).__setstate__(state)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
past: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
Queries are compared against key-value pairs to produce the output.
See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
Binary and byte masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
Default: ``True``.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
Outputs:
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
.. note::
`batch_first` argument is ignored for unbatched inputs.
"""
is_batched = query.dim() == 3
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif (
self.in_proj_bias is not None
and query.dtype != self.in_proj_bias.dtype
):
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (
self.in_proj_weight is not None
and query.dtype != self.in_proj_weight.dtype
):
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
why_not_fast_path = "batch_first was not True"
elif self.bias_k is not None:
why_not_fast_path = "self.bias_k was not None"
elif self.bias_v is not None:
why_not_fast_path = "self.bias_v was not None"
elif self.dropout:
why_not_fast_path = f"dropout was {self.dropout}, required zero"
elif self.add_zero_attn:
why_not_fast_path = "add_zero_attn was enabled"
elif not self._qkv_same_embed_dim:
why_not_fast_path = "_qkv_same_embed_dim was not True"
elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = (
"key_padding_mask is not supported with NestedTensor input"
)
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"
if not why_not_fast_path:
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all(
[
(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args
]
):
why_not_fast_path = (
"some Tensor argument is neither CUDA nor CPU"
)
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad"
)
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask
if key_padding_mask is not None
else attn_mask,
need_weights,
average_attn_weights,
1
if key_padding_mask is not None
else 0
if attn_mask is not None
else None,
)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, (
"MultiheadAttention does not support NestedTensor outside of its fast path. "
+ f"The fast path was not hit because {why_not_fast_path}"
)
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [
x.transpose(1, 0) for x in (query, key, value)
]
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,
)
else:
# re-write the self.attention here, to get k, v cache
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
num_heads = self.num_heads
key_padding_mask = _canonical_mask(
mask=key_padding_mask,
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype
)
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=query.dtype,
check_other=False,
)
head_dim = self.embed_dim // self.num_heads
assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}"
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
# k_present, v_present = k, v
#
# reshape q, k, v for multihead attention and make em batch first
#
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) # (bsz * num_heads, src_len, head_dim)
src_len = k.size(1)
if past is not None and past.ndim > 2:
expected_src_len = src_len + past[0].shape[-2]
else:
expected_src_len = src_len
# ensure attn_mask's dim is 3
if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, expected_src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, expected_src_len), \
f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len)
if attn_mask is None:
attn_mask = key_padding_mask
else:
attn_mask = attn_mask + key_padding_mask
if not self.training:
dropout_p = 0.0
else:
dropout_p = self.dropout
if need_weights:
raise NotImplementedError("need_weights not implemented for voicecraft")
# B, Nt, E = q.shape
# q_scaled = q / math.sqrt(E)
# assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
# if attn_mask is not None:
# attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
# else:
# attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
# attn_output_weights = softmax(attn_output_weights, dim=-1)
# if dropout_p > 0.0:
# attn_output_weights = dropout(attn_output_weights, p=dropout_p)
# attn_output = torch.bmm(attn_output_weights, v)
# attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
# attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
# attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
# # optionally average attention weights over heads
# attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
# if average_attn_weights:
# attn_output_weights = attn_output_weights.mean(dim=1)
# if not is_batched:
# # squeeze the output if input was unbatched
# attn_output = attn_output.squeeze(1)
# attn_output_weights = attn_output_weights.squeeze(0)
# return attn_output, attn_output_weights
else:
# attn_mask can be either (L,S) or (N*num_heads, L, S)
# if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
# in order to match the input for SDPA of (N, num_heads, L, S)
if attn_mask is not None:
if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
attn_mask = attn_mask.unsqueeze(0)
else:
attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len)
q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
# logging.info(f"shape of past: {past.shape}")
if past is not None:
present = torch.stack([k, v], dim=0) # (2, bsz, num_heads, src_len, head_dim)
if past.ndim > 2: # this means we use kvcache, otherwise we just pass in a placeholder, but not actually using kvcache
pk, pv = past
k = torch.cat([pk, k], dim=-2)
v = torch.cat([pv, v], dim=-2)
else:
present = None
attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
# if self.training:
# return attn_output, None
# else:
# return (attn_output, present), None
# harded coded, the code do not support returning attn weigths yet
attn_output_weights=None
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), present
else:
return attn_output, present

View File

@ -0,0 +1,98 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
class TokenEmbedding(nn.Module):
def __init__(
self,
dim_model: int,
vocab_size: int,
dropout: float = 0.0,
):
super().__init__()
self.vocab_size = vocab_size
self.dim_model = dim_model
self.dropout = torch.nn.Dropout(p=dropout)
self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
@property
def weight(self) -> torch.Tensor:
return self.word_embeddings.weight
def embedding(self, index: int) -> torch.Tensor:
return self.word_embeddings.weight[index : index + 1]
def forward(self, x: torch.Tensor):
X = self.word_embeddings(x)
X = self.dropout(X)
return X
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
dim_model: int,
dropout: float = 0.0,
scale: bool = False,
alpha: bool = False,
):
super().__init__()
self.dim_model = dim_model
self.x_scale = math.sqrt(dim_model) if scale else 1.0
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
self.dropout = torch.nn.Dropout(p=dropout)
self.reverse = False
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, 4000))
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.dim_model)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(
0, x.size(1), dtype=torch.float32
).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.dim_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.dim_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.extend_pe(x)
output = x.unsqueeze(-1) if x.ndim == 2 else x
output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
return self.dropout(output)

View File

@ -0,0 +1,63 @@
import torch
import torch.nn.functional as F
def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(
max(top_k, min_tokens_to_keep), logits.size(-1)
) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1
)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1
].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
# temperature: (`optional`) float
# The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
# top_k: (`optional`) int
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
# top_p: (`optional`) float
# The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
logits = logits / temperature
# Top-p/top-k filtering
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
# Sample
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
return token

1406
models/modules/scaling.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,698 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng 2024
import copy
import numbers
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from .activation import MultiheadAttention
from .scaling import ActivationBalancer, BalancedDoubleSwish
from .scaling import BasicNorm as _BasicNorm
_shape_t = Union[int, List[int], torch.Size]
class LayerNorm(nn.Module):
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
def __init__(
self,
normalized_shape: _shape_t,
eps: float = 1e-5,
elementwise_affine: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (
F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps,
),
embedding,
)
assert embedding is None
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
class AdaptiveLayerNorm(nn.Module):
r"""Adaptive Layer Normalization"""
def __init__(self, d_model, norm) -> None:
super(AdaptiveLayerNorm, self).__init__()
self.project_layer = nn.Linear(d_model, 2 * d_model)
self.norm = norm
self.d_model = d_model
self.eps = self.norm.eps
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1,
)
return (weight * self.norm(input) + bias, embedding)
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1,
)
return weight * self.norm(input) + bias
class BasicNorm(_BasicNorm):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
):
super(BasicNorm, self).__init__(d_model, eps=eps)
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (
super(BasicNorm, self).forward(input),
embedding,
)
assert embedding is None
return super(BasicNorm, self).forward(input)
class BalancedBasicNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
):
super(BalancedBasicNorm, self).__init__()
self.balancer = ActivationBalancer(
d_model,
channel_dim=-1,
min_positive=0.45,
max_positive=0.55,
max_abs=6.0,
)
self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return self.norm((self.balancer(input), embedding))
assert embedding is None
return self.norm(self.balancer(input))
class IdentityNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
) -> None:
super(IdentityNorm, self).__init__()
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
return input
assert embedding is None
return input
class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
activation = _get_activation_fn(activation)
elif isinstance(activation, partial):
activation = activation(d_model)
elif activation == BalancedDoubleSwish:
activation = BalancedDoubleSwish(d_model)
# # We can't test self.activation in forward() in TorchScript,
# # so stash some information about it instead.
# if activation is F.relu or isinstance(activation, torch.nn.ReLU):
# self.activation_relu_or_gelu = 1
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
# self.activation_relu_or_gelu = 2
# else:
# self.activation_relu_or_gelu = 0
self.activation = activation
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs
)
else:
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
else:
self.norm1 = norm1
self.norm2 = norm2
def __setstate__(self, state):
super(TransformerEncoderLayer, self).__setstate__(state)
if not hasattr(self, "activation"):
self.activation = F.relu
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
need_weights: Optional[bool] = False,
past: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
x, stage_embedding = src, None
is_src_tuple = False
if isinstance(src, tuple):
x, stage_embedding = src
is_src_tuple = True
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
if need_weights:
if self.norm_first:
out, attn = self._sa_block_attn(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask,
past
)
out, present = out # present is the kvcache of the present timestep
x = x + out
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past)
out, present = out # present is the kvcache of the present timestep
x = self.norm1(
x + out,
stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding)
assert not is_src_tuple
# return (x, stage_embedding)
return (x, attn)
else:
if self.norm_first:
out = self._sa_block(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask, past
)
out, present = out # present is the kvcache of the present timestep
x = x + out
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
out = self._sa_block(x, src_mask, src_key_padding_mask)
out, present = out # present is the kvcache of the present timestep
x = self.norm1(
x + out,
stage_embedding, past
)
x = self.norm2(x + self._ff_block(x), stage_embedding)
if is_src_tuple:
x = (x, stage_embedding)
if present != None:
x = [x, present]
return x
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
past: Optional[Tensor] = None,
) -> Tensor:
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
past=past
)
x, present = x
return self.dropout1(x), present
# self-attention block, also return attention weights
def _sa_block_attn(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
past: Optional[Tensor] = None,
) -> Tensor:
x, attn = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=True,
past=past
)
x, present = x
return (self.dropout1(x), present), attn
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
class TransformerEncoder(nn.Module):
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
Args:
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
enable_nested_tensor: if True, input will automatically convert to nested tensor
(and convert back on output). This will improve the overall performance of
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
return_layer_states: bool = False,
need_weights:Optional[bool] = False,
past: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
return_layer_states: return layers' state (optional).
Shape:
see the docs in Transformer class.
"""
if return_layer_states:
assert not need_weights
layer_states = [] # layers' output
output = src
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
past=past
)
layer_states.append(output[0])
if self.norm is not None:
output = self.norm(output)
return layer_states, output
if need_weights:
assert not return_layer_states
layer_attn = [] # layers' output
output = src
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
need_weights=True,
past=past
)
layer_attn.append(output[1])
if self.norm is not None:
output = self.norm(output)
return layer_attn, output
output = src
all_present = []
for n_layer, mod in enumerate(self.layers):
output = mod(
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer]
)
if isinstance(output, list):
output, present = output
all_present.append(present)
if self.norm is not None:
output = self.norm(output)
if all_present != []:
all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
output = [output, all_present]
return output
class TransformerDecoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerDecoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
self.multihead_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
self.activation = _get_activation_fn(activation)
elif isinstance(activation, partial):
self.activation = activation(d_model)
elif activation == BalancedDoubleSwish:
self.activation = BalancedDoubleSwish(d_model)
else:
self.activation = activation
if adaptive_layer_norm:
norm1 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
norm3 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
self.norm3 = AdaptiveLayerNorm(d_model, norm3)
else:
self.norm1 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
self.norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
if layer_norm_cls == IdentityNorm:
self.norm3 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs
)
else:
self.norm3 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
def forward(
self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
tgt_is_tuple = False
if isinstance(tgt, tuple):
x, stage_embedding = tgt
tgt_is_tuple = True
else:
x, stage_embedding = tgt, None
if self.norm_first:
x = x + self._sa_block(
self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
)
x = x + self._mha_block(
self.norm2(x, stage_embedding),
memory,
memory_mask,
memory_key_padding_mask,
)
x = x + self._ff_block(self.norm3(x, stage_embedding))
else:
x = self.norm1(
x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
stage_embedding,
)
x = self.norm2(
x
+ self._mha_block(
x, memory, memory_mask, memory_key_padding_mask
),
stage_embedding,
)
x = self.norm3(x + self._ff_block(x), stage_embedding)
if tgt_is_tuple:
return (x, stage_embedding)
return x
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
) -> Tensor:
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout1(x)
# multihead attention block
def _mha_block(
self,
x: Tensor,
mem: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
) -> Tensor:
x = self.multihead_attn(
x,
mem,
mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout2(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout3(x)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)

37
models/modules/utils.py Normal file
View File

@ -0,0 +1,37 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng
import torch
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
max_len:
The length of masks.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
>>> lengths = torch.tensor([1, 3, 2, 5])
>>> make_pad_mask(lengths)
tensor([[False, True, True, True, True],
[False, False, False, True, True],
[False, False, True, True, True],
[False, False, False, False, False]])
"""
assert lengths.ndim == 1, lengths.ndim
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
return expaned_lengths >= lengths.unsqueeze(-1)
def generate_partial_autoregressive_mask(sz, start, end):
mask = torch.zeros(sz, sz).bool()
mask[start:end, start:end] = torch.triu(torch.ones(end-start, end-start,dtype=torch.bool), diagonal=1)
mask[:start, start:end] = True
mask[end:, start:end] = True
return mask

1402
models/voicecraft.py Normal file

File diff suppressed because it is too large Load Diff