# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py, modified by Puyuan Peng, 2024

from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union

import logging
import warnings

import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter

if TYPE_CHECKING:
    from torch.types import _dtype as DType
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int


def _canonical_mask(
    mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[DType],
    other_name: str,
    target_type: DType,
    check_other: bool = True,
) -> Optional[Tensor]:
    if mask is not None:
        _mask_dtype = mask.dtype
        _mask_is_float = torch.is_floating_point(mask)
        if _mask_dtype != torch.bool and not _mask_is_float:
            raise AssertionError(
                f"only bool and floating types of {mask_name} are supported")
        if check_other and other_type is not None:
            if _mask_dtype != other_type:
                warnings.warn(
                    f"Support for mismatched {mask_name} and {other_name} "
                    "is deprecated. Use same type for both instead."
                )
        if not _mask_is_float:
            mask = (
                torch.zeros_like(mask, dtype=target_type)
                .masked_fill_(mask, float("-inf"))
            )
    return mask
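

# A minimal, illustrative sketch (not used anywhere in this file): it shows how
# _canonical_mask converts a boolean padding mask into the additive float mask
# that gets added to the attention logits. The sizes below are made up.
def _demo_canonical_mask() -> Tensor:
    bool_mask = torch.tensor([[False, False, True]])  # True marks padded positions
    float_mask = _canonical_mask(
        mask=bool_mask,
        mask_name="key_padding_mask",
        other_type=None,
        other_name="",
        target_type=torch.float32,
    )
    # float_mask == tensor([[0., 0., -inf]]); adding it to attention scores
    # removes the padded position after softmax.
    return float_mask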


def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.

    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension

        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
          same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # self-attention
            proj = F.linear(q, w, b)
            # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
            proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
            return proj[0], proj[1], proj[2]
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            q_proj = F.linear(q, w_q, b_q)
            kv_proj = F.linear(k, w_kv, b_kv)
            # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
            kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
            return (q_proj, kv_proj[0], kv_proj[1])
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
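

# A minimal, illustrative sketch (not used anywhere in this file): the self-attention
# branch of _in_projection_packed with made-up sizes, showing that the packed
# (3*E, E) weight yields q/k/v projections shaped like the input.
def _demo_in_projection_packed() -> Tuple[Tensor, Tensor, Tensor]:
    L, N, E = 4, 2, 8          # target length, batch size, embedding dim
    x = torch.randn(L, N, E)   # query is key is value for self-attention
    w = torch.randn(3 * E, E)  # q/k/v weights packed along dim 0
    b = torch.zeros(3 * E)     # q/k/v biases packed along dim 0
    q, k, v = _in_projection_packed(x, x, x, w, b)
    assert q.shape == k.shape == v.shape == (L, N, E)
    return q, k, v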


def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
    if input is None:
        return None
    elif isinstance(input, torch.Tensor):
        return input.dtype
    raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    ``forward()`` will use a special optimized implementation if all of the following
    conditions are met:

    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
      restriction will be loosened in the future.)
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - dropout is 0
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``batch_first`` is ``True`` and the input is batched
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
      nor ``attn_mask`` is passed

    If the optimized implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the input
    that is padding can be expected.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Examples::

        >>> # xdoctest: +SKIP
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    """

    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = (
            self.kdim == embed_dim and self.vdim == embed_dim
        )

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_v = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                # go down this route with voicecraft
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:  # True by default
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

                self.out_proj = linear2_cls(
                    embed_dim, embed_dim, bias=bias, **factory_kwargs
                )

                if self.bias_k is not None:
                    xavier_normal_(self.bias_k)
                if self.bias_v is not None:
                    xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        past: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
                :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
                Queries are compared against key-value pairs to produce the output.
                See "Attention Is All You Need" for more details.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
                or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
                :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
                See "Attention Is All You Need" for more details.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
                sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
                See "Attention Is All You Need" for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
                Binary and float masks are supported.
                For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Default: ``True``. Note that the rewritten packed-projection path below raises
                ``NotImplementedError`` when ``need_weights=True``, so pass ``need_weights=False`` to use it.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
                :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
                broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
                Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
                corresponding position is not allowed to attend. For a float mask, the mask values will be added to
                the attention weight.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
                heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
                effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
            past: Optional key/value cache for the rewritten packed-projection path. A tensor with more than
                two dimensions is interpreted as a cache of shape
                :math:`(2, N, \text{num\_heads}, S_\text{past}, \text{head\_dim})` whose keys and values are
                prepended to the current ones before attention; a low-dimensional tensor acts as a placeholder
                that only requests the freshly computed key/value stack as the second output.

        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
              embedding dimension ``embed_dim``.
            - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
              returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
              :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
            - In the rewritten packed-projection path, the second return value is instead the freshly computed
              key/value stack (or ``None`` when ``past`` is ``None``); see ``past`` above.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        is_batched = query.dim() == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif (
            self.in_proj_bias is not None
            and query.dtype != self.in_proj_bias.dtype
        ):
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
            self.in_proj_weight is not None
            and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.dropout:
            why_not_fast_path = f"dropout was {self.dropout}, required zero"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif attn_mask is not None:
            why_not_fast_path = "attn_mask was not None"
        elif query.is_nested and key_padding_mask is not None:
            why_not_fast_path = (
                "key_padding_mask is not supported with NestedTensor input"
            )
        elif self.num_heads % 2 == 1:
            why_not_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                [
                    (x is None or x.is_cuda or "cpu" in str(x.device))
                    for x in tensor_args
                ]
            ):
                why_not_fast_path = (
                    "some Tensor argument is neither CUDA nor CPU"
                )
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                return torch._native_multi_head_attention(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    self.in_proj_weight,
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    key_padding_mask
                    if key_padding_mask is not None
                    else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
                    if key_padding_mask is not None
                    else 0
                    if attn_mask is not None
                    else None,
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
            "MultiheadAttention does not support NestedTensor outside of its fast path. "
            + f"The fast path was not hit because {why_not_fast_path}"
        )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [
                    x.transpose(1, 0) for x in (query, key, value)
                ]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
            )
            # the key/value cache only exists in the packed-projection branch below
            present = None
        else:
            # re-write the self.attention here, to get k, v cache
            tgt_len, bsz, embed_dim = query.shape
            src_len, _, _ = key.shape
            num_heads = self.num_heads
            key_padding_mask = _canonical_mask(
                mask=key_padding_mask,
                mask_name="key_padding_mask",
                other_type=_none_or_dtype(attn_mask),
                other_name="attn_mask",
                target_type=query.dtype,
            )
            attn_mask = _canonical_mask(
                mask=attn_mask,
                mask_name="attn_mask",
                other_type=None,
                other_name="",
                target_type=query.dtype,
                check_other=False,
            )
            head_dim = self.embed_dim // self.num_heads
            assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}"
            assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
            q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
            # k_present, v_present = k, v

            #
            # reshape q, k, v for multihead attention and make them batch first
            #
            q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
            k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
            v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)  # (bsz * num_heads, src_len, head_dim)
            src_len = k.size(1)
            if past is not None and past.ndim > 2:
                # a real cache was passed: the effective source length includes the cached steps
                expected_src_len = src_len + past[0].shape[-2]
            else:
                expected_src_len = src_len

            # ensure attn_mask's dim is 3 (skip entirely when no mask was given)
            if attn_mask is not None:
                if attn_mask.dim() == 2:
                    correct_2d_size = (tgt_len, expected_src_len)
                    if attn_mask.shape != correct_2d_size:
                        raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                    attn_mask = attn_mask.unsqueeze(0)
                elif attn_mask.dim() == 3:
                    correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len)
                    if attn_mask.shape != correct_3d_size:
                        raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
                else:
                    raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

            if key_padding_mask is not None:
                assert key_padding_mask.shape == (bsz, expected_src_len), \
                    f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}"
                key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \
                    expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len)
                if attn_mask is None:
                    attn_mask = key_padding_mask
                else:
                    attn_mask = attn_mask + key_padding_mask

            if not self.training:
                dropout_p = 0.0
            else:
                dropout_p = self.dropout

            if need_weights:
                raise NotImplementedError("need_weights not implemented for voicecraft")
                # B, Nt, E = q.shape
                # q_scaled = q / math.sqrt(E)

                # assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"

                # if attn_mask is not None:
                #     attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
                # else:
                #     attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
                # attn_output_weights = softmax(attn_output_weights, dim=-1)
                # if dropout_p > 0.0:
                #     attn_output_weights = dropout(attn_output_weights, p=dropout_p)

                # attn_output = torch.bmm(attn_output_weights, v)

                # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
                # attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
                # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

                # # optionally average attention weights over heads
                # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
                # if average_attn_weights:
                #     attn_output_weights = attn_output_weights.mean(dim=1)

                # if not is_batched:
                #     # squeeze the output if input was unbatched
                #     attn_output = attn_output.squeeze(1)
                #     attn_output_weights = attn_output_weights.squeeze(0)
                # return attn_output, attn_output_weights
            else:
                # attn_mask can be either (L,S) or (N*num_heads, L, S)
                # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
                # in order to match the input for SDPA of (N, num_heads, L, S)
                if attn_mask is not None:
                    if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                        attn_mask = attn_mask.unsqueeze(0)
                    else:
                        attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len)

                q = q.view(bsz, num_heads, tgt_len, head_dim)
                k = k.view(bsz, num_heads, src_len, head_dim)
                v = v.view(bsz, num_heads, src_len, head_dim)
                # logging.info(f"shape of past: {past.shape}")
                if past is not None:
                    # `present` holds only the keys/values computed in this call; the caller
                    # is responsible for concatenating it onto its running cache.
                    present = torch.stack([k, v], dim=0)  # (2, bsz, num_heads, src_len, head_dim)
                    if past.ndim > 2:  # this means we use kvcache, otherwise we just pass in a placeholder, but not actually using kvcache
                        pk, pv = past
                        k = torch.cat([pk, k], dim=-2)
                        v = torch.cat([pv, v], dim=-2)
                else:
                    present = None
                attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False)
                attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

                attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)
                attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
                if not is_batched:
                    # squeeze the output if input was unbatched
                    attn_output = attn_output.squeeze(1)
                # if self.training:
                #     return attn_output, None
                # else:
                #     return (attn_output, present), None

        # hard-coded for now: this implementation does not support returning attention weights yet
        attn_output_weights = None
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), present
        else:
            return attn_output, present
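

# A minimal, illustrative sketch (not part of VoiceCraft or vall-e): it exercises the
# key/value-cache path of MultiheadAttention.forward via the `past` argument.
# All sizes, tensor names, and the placeholder convention shown here are assumptions
# made for the example, not an official API.
def _demo_kv_cache_decoding() -> Tensor:
    torch.manual_seed(0)
    embed_dim, num_heads, bsz = 8, 2, 1
    mha = MultiheadAttention(embed_dim, num_heads)  # training mode keeps us off the fast path

    # Step 1: run a 3-token prefix. A low-dimensional `past` placeholder only asks
    # forward() to hand back the freshly computed key/value stack.
    prefix = torch.randn(3, bsz, embed_dim)  # (L, N, E) with batch_first=False
    causal = torch.triu(torch.full((3, 3), float("-inf")), diagonal=1)
    out, present = mha(prefix, prefix, prefix, need_weights=False,
                       attn_mask=causal, past=torch.zeros(1))
    cache = present  # (2, N, num_heads, 3, head_dim)

    # Step 2: feed one new token; the cached keys/values are prepended inside forward().
    new_tok = torch.randn(1, bsz, embed_dim)
    step_mask = torch.zeros(1, cache.shape[-2] + 1)  # (tgt_len, cached_len + 1)
    out, present = mha(new_tok, new_tok, new_tok, need_weights=False,
                       attn_mask=step_mask, past=cache)
    cache = torch.cat([cache, present], dim=-2)  # the caller accumulates the cache
    return out  # (1, N, E)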