# (Web-page header residue, not part of the module: "Spaces: Running on Zero")
# Copied from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
import logging
import warnings
from typing import Optional, Tuple
from typing import Callable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING

import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter

if TYPE_CHECKING:
    from torch.types import _dtype as DType
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
def _canonical_mask(
    mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[DType],
    other_name: str,
    target_type: DType,
    check_other: bool = True,
) -> Optional[Tensor]:
    """Validate an attention mask and convert a boolean mask to additive float form.

    Args:
        mask: The mask to canonicalise, or ``None``.
        mask_name: Name of ``mask``, used only in error / warning messages.
        other_type: dtype of a companion mask to compare against, or ``None``.
        other_name: Name of the companion mask, used only in the warning message.
        target_type: dtype of the float mask produced from a boolean ``mask``.
        check_other: When true (and ``other_type`` is not ``None``), warn if the
            dtype of ``mask`` differs from ``other_type``.

    Returns:
        ``None`` if ``mask`` is ``None``; ``mask`` itself if it is already a
        floating-point tensor; otherwise a new tensor of dtype ``target_type``
        holding ``-inf`` where ``mask`` is ``True`` and ``0`` elsewhere.

    Raises:
        AssertionError: If ``mask`` is neither boolean nor floating point.
    """
    if mask is None:
        return None

    mask_dtype = mask.dtype
    mask_is_float = torch.is_floating_point(mask)
    if mask_dtype != torch.bool and not mask_is_float:
        raise AssertionError(
            f"only bool and floating types of {mask_name} are supported")

    if check_other and other_type is not None and mask_dtype != other_type:
        warnings.warn(
            f"Support for mismatched {mask_name} and {other_name} "
            "is deprecated. Use same type for both instead."
        )

    if not mask_is_float:
        # Boolean mask: True marks a masked position, which becomes -inf so it
        # can simply be added to the attention scores.
        mask = (
            torch.zeros_like(mask, dtype=target_type)
            .masked_fill_(mask, float("-inf"))
        )
    return mask
def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""Project query, key and value with a single packed weight (and bias).

    The packed weight ``w`` holds the q, k and v projection matrices stacked
    along dimension 0, in that order; ``b`` (if given) is packed the same way.
    Object identity between the inputs is used to pick the cheapest route:
    one matmul when ``q is k is v``, two when only ``k is v``, three otherwise.

    Args:
        q, k, v: Tensors to project. They must share the size of their last
            dimension ``E`` (taken from ``q``).
        w: Packed projection weights of shape ``(3 * E, E)``.
        b: Optional packed projection biases of shape ``(3 * E,)``.

    Returns:
        A 3-tuple ``(q', k', v')`` of projected tensors, each with the same
        shape as the corresponding input.
    """
    E = q.size(-1)

    if k is not v:
        # Fully separate inputs: three independent projections.
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

    if q is k:
        # Self-attention: project once, then split the last dim into (3, E).
        proj = F.linear(q, w, b)
        # Reshape to (3, E) and not (E, 3) deliberately, for better memory
        # coalescing and to keep the same order as chunk().
        proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
        return proj[0], proj[1], proj[2]

    # Encoder-decoder attention: q on its own, k and v projected together.
    w_q, w_kv = w.split([E, E * 2])
    if b is None:
        b_q = b_kv = None
    else:
        b_q, b_kv = b.split([E, E * 2])
    q_proj = F.linear(q, w_q, b_q)
    kv_proj = F.linear(k, w_kv, b_kv)
    # Reshape to (2, E) and not (E, 2) deliberately, for the same reasons as above.
    kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
    return (q_proj, kv_proj[0], kv_proj[1])
def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
    """Return the dtype of ``input``, or ``None`` when ``input`` is ``None``.

    Raises:
        RuntimeError: If ``input`` is neither ``None`` nor a ``torch.Tensor``.
    """
    if input is None:
        return None
    if isinstance(input, torch.Tensor):
        return input.dtype
    raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
def rotate_half(x: Tensor) -> Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    The last dimension is split at ``x.shape[-1] // 2`` into ``(x1, x2)`` and
    the result is ``concat(-x2, x1)`` along that dimension.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat([-x2, x1], dim=-1)
def apply_rotary_pos_emb(q, k, q_sinu=None, k_sinu=None, sinu=None, unsqueeze_dim=1, args=None, q_offset=0):
    """Apply rotary position embeddings to ``q`` and ``k``.

    Each tensor ``x`` is mapped to ``x * cos + rotate_half(x) * sin``. The
    cos/sin tables come from exactly one of two sources:

    * ``sinu``: one table shared by queries and keys. Positions are sliced
      along dim 1 of ``sinu['cos']`` / ``sinu['sin']`` and the result is
      unsqueezed at ``unsqueeze_dim`` before being multiplied with ``q``/``k``.
    * ``q_sinu`` and ``k_sinu``: separate tables for queries and keys.
      Positions are sliced along dim 2 and used without unsqueezing.

    The number of positions taken is ``q.shape[2]`` for queries (starting at
    ``q_offset``) and ``k.shape[2]`` for keys (starting at 0).

    Args:
        q, k: Query and key tensors; dim 2 is treated as the position axis.
        q_sinu, k_sinu: Dicts with ``'cos'`` and ``'sin'`` tensors; must be
            given together and not combined with ``sinu``.
        sinu: Dict with ``'cos'`` and ``'sin'`` tensors.
        unsqueeze_dim: Dimension inserted into the sliced ``sinu`` tables.
        args: Accepted for interface compatibility; not used here.
        q_offset: Index of the first query position in the table.

    Returns:
        Tuple ``(q_emb, k_emb)``.

    Raises:
        ValueError: If neither ``sinu`` nor ``q_sinu`` is given, or if
            ``q_sinu`` is given without ``k_sinu``.
        AssertionError: If both ``sinu`` and ``q_sinu`` are given.
    """
    if sinu is None and q_sinu is None:
        # Previously this fell through to the return with unbound locals.
        raise ValueError("either sinu or q_sinu/k_sinu must be provided")

    q_pos = slice(q_offset, q_offset + q.shape[2])
    k_pos = slice(0, k.shape[2])

    if q_sinu is not None:
        if sinu is not None:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise AssertionError("sinu must be None")
        if k_sinu is None:
            raise ValueError("k_sinu must be provided together with q_sinu")
        q_emb = q * q_sinu['cos'][:, :, q_pos] + rotate_half(q) * q_sinu['sin'][:, :, q_pos]
        k_emb = k * k_sinu['cos'][:, :, k_pos] + rotate_half(k) * k_sinu['sin'][:, :, k_pos]
        return q_emb, k_emb

    q_cos = sinu['cos'][:, q_pos].unsqueeze(unsqueeze_dim)
    q_sin = sinu['sin'][:, q_pos].unsqueeze(unsqueeze_dim)
    k_cos = sinu['cos'][:, k_pos].unsqueeze(unsqueeze_dim)
    k_sin = sinu['sin'][:, k_pos].unsqueeze(unsqueeze_dim)
    q_emb = q * q_cos + rotate_half(q) * q_sin
    k_emb = k * k_cos + rotate_half(k) * k_sin
    return q_emb, k_emb
class MultiheadAttention(Module):
    """Multi-head attention with optional rotary position embeddings and kv-cache.

    Adapted from ``torch.nn.MultiheadAttention`` (see
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_).
    ``forward()`` has three execution paths:

    * the native fused path (``torch._native_multi_head_attention``), used only
      when every condition checked at the top of ``forward()`` holds and
      neither a kv-cache nor rotary embeddings are requested;
    * ``F.multi_head_attention_forward`` with separate q/k/v projection
      weights, used when ``kdim`` or ``vdim`` differs from ``embed_dim``;
    * a custom path built on ``F.scaled_dot_product_attention`` that supports
      ``past`` (kv-cache) and rotary embeddings (``sinu`` / ``q_sinu``).

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of attention heads; ``embed_dim`` must be divisible
            by it (each head has dimension ``embed_dim // num_heads``).
        dropout: Dropout probability applied inside attention. Default ``0.0``.
        bias: Whether the input / output projections have a bias. Default ``True``.
        add_bias_kv: Whether to create learnable ``bias_k`` / ``bias_v``
            parameters of shape ``(1, 1, embed_dim)``. Default ``False``.
        add_zero_attn: Stored and forwarded to
            ``F.multi_head_attention_forward``. Default ``False``.
        kdim: Number of key features. Default ``None`` (uses ``embed_dim``).
        vdim: Number of value features. Default ``None`` (uses ``embed_dim``).
        batch_first: If ``True``, batched inputs and outputs are
            ``(batch, seq, feature)``; otherwise ``(seq, batch, feature)``.
        linear1_cls: Class used for the packed input projection. Anything
            other than ``torch.nn.Linear`` requires ``kdim == vdim == embed_dim``.
        linear2_cls: Class used for the output projection when ``linear1_cls``
            is not ``torch.nn.Linear``.
        device, dtype: Forwarded to parameter / sub-module construction.
    """

    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = (
            self.kdim == embed_dim and self.vdim == embed_dim
        )

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_v = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                # go down this route with music_gen
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:  # True by default
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            # Custom projection class: build the packed projection as a
            # sub-module and expose its weight / bias under the usual names.
            self.in_proj_linear = linear1_cls(
                embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
            )
            self.in_proj_weight = self.in_proj_linear.weight

            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = self.in_proj_linear.bias
            else:
                self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            if self.bias_k is not None:
                xavier_normal_(self.bias_k)
            if self.bias_v is not None:
                xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        """Initialise projection weights (Xavier) and zero the projection biases."""
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super().__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        past: Optional[Tensor] = None,
        q_sinu=None,
        k_sinu=None,
        sinu=None,
        args=None,
        q_offset=0,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Compute multi-head attention.

        Args:
            query: ``(L, N, E)`` when ``batch_first=False`` or ``(N, L, E)``
                when ``batch_first=True``.
            key: ``(S, N, E_k)`` or ``(N, S, E_k)``, laid out like ``query``.
            value: ``(S, N, E_v)`` or ``(N, S, E_v)``, laid out like ``query``.
            key_padding_mask: Optional bool or float mask. On the custom path
                it must have shape ``(N, S_total)`` where ``S_total`` includes
                any cached positions from ``past``. For a bool mask ``True``
                marks positions to ignore.
            need_weights: Whether attention weights are requested. The custom
                (packed-projection) path raises ``NotImplementedError`` when
                this is ``True``.
            attn_mask: Optional 2-D ``(L, S_total)`` or 3-D
                ``(N * num_heads, L, S_total)`` bool or float mask.
            average_attn_weights: Forwarded to the native / functional
                implementations; unused on the custom path.
            past: ``None`` for no cache. Otherwise, a tensor with more than
                two dimensions is unpacked as ``(past_k, past_v)`` and
                prepended to the new keys / values along the sequence axis; a
                tensor with at most two dimensions acts as a placeholder that
                only requests ``present`` to be returned.
            q_sinu, k_sinu: Separate cos/sin tables for queries and keys,
                passed to ``apply_rotary_pos_emb``; must be given together.
            sinu: Shared cos/sin table passed to ``apply_rotary_pos_emb``;
                mutually exclusive with ``q_sinu``.
            args: Optional config object. If it has a truthy
                ``attention_alignment_loss`` attribute, the module is in
                training mode and ``query is not key``, raw ``q @ k^T`` scores
                are returned alongside the output.
            q_offset: Position of the first query in the rotary table (used
                with a kv-cache, where the query is usually a single step).

        Returns:
            A 2-tuple. On the custom path it is ``(attn_output, present)``:
            ``present`` is ``torch.stack([k, v])`` of the newly projected keys
            and values when ``past`` is not ``None`` (else ``None``), and
            ``attn_output`` is replaced by the dict
            ``{"attn_output": ..., "attention_weights": ...}`` when alignment
            scores were computed. On the native and separate-projection paths
            it is ``(attn_output, attn_output_weights)`` as produced by the
            underlying PyTorch implementation.

        Raises:
            AssertionError: For an unsupported ``key_padding_mask`` dtype, for
                nested-tensor input outside the native path, or for
                inconsistent shapes / rotary arguments.
            NotImplementedError: If ``need_weights`` is true on the custom
                path, or if ``past`` / ``sinu`` / ``q_sinu`` is given while
                ``kdim`` or ``vdim`` differs from ``embed_dim``.
            RuntimeError: If ``attn_mask`` has an unexpected shape or rank.
        """
        is_batched = query.dim() == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        # The native fused kernel knows nothing about the kv-cache or rotary
        # embeddings, so it must never be selected when either is requested.
        uses_cache_or_rope = (
            past is not None or sinu is not None or q_sinu is not None
        )

        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
        elif uses_cache_or_rope:
            why_not_fast_path = "past, sinu or q_sinu was provided (kv-cache / rotary embeddings need the custom path)"
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif (
            self.in_proj_bias is not None
            and query.dtype != self.in_proj_bias.dtype
        ):
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
            self.in_proj_weight is not None
            and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.dropout:
            why_not_fast_path = f"dropout was {self.dropout}, required zero"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif attn_mask is not None:
            why_not_fast_path = "attn_mask was not None"
        elif query.is_nested and key_padding_mask is not None:
            why_not_fast_path = (
                "key_padding_mask is not supported with NestedTensor input"
            )
        elif self.num_heads % 2 == 1:
            why_not_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                [
                    (x is None or x.is_cuda or "cpu" in str(x.device))
                    for x in tensor_args
                ]
            ):
                why_not_fast_path = (
                    "some Tensor argument is neither CUDA nor CPU"
                )
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                return torch._native_multi_head_attention(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    self.in_proj_weight,
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    key_padding_mask
                    if key_padding_mask is not None
                    else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
                    if key_padding_mask is not None
                    else 0
                    if attn_mask is not None
                    else None,
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
            "MultiheadAttention does not support NestedTensor outside of its fast path. "
            + f"The fast path was not hit because {why_not_fast_path}"
        )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [
                    x.transpose(1, 0) for x in (query, key, value)
                ]

        if not self._qkv_same_embed_dim:
            # Separate q/k/v projection weights: delegate to the stock PyTorch
            # implementation, which has no kv-cache or rotary support.
            if uses_cache_or_rope:
                raise NotImplementedError(
                    "kv-cache and rotary embeddings are only implemented "
                    "when kdim == vdim == embed_dim"
                )
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
            )
            if self.batch_first and is_batched:
                return attn_output.transpose(1, 0), attn_output_weights
            return attn_output, attn_output_weights

        # music_gen should go down this route.
        # Attention is re-implemented here (instead of calling
        # F.multi_head_attention_forward) so the k, v cache can be returned.
        if need_weights:
            raise NotImplementedError("need_weights not implemented for music_gen")

        # NOTE(review): this path expects 3-D (seq, batch, embed) tensors; the
        # unpacking below fails for unbatched 2-D input.
        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape
        num_heads = self.num_heads

        key_padding_mask = _canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=_none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype,
        )
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        head_dim = self.embed_dim // self.num_heads
        assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}"
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

        q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)

        # reshape q, k, v for multihead attention and make them batch first
        q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)  # (bsz * num_heads, src_len, head_dim)

        src_len = k.size(1)
        # With a real cache the masks must also cover the cached positions.
        if past is not None and past.ndim > 2:
            expected_src_len = src_len + past[0].shape[-2]
        else:
            expected_src_len = src_len

        # ensure attn_mask's dim is 3
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                correct_2d_size = (tgt_len, expected_src_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        # Fold the (already additive) key padding mask into attn_mask.
        if key_padding_mask is not None:
            assert key_padding_mask.shape == (bsz, expected_src_len), \
                f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}"
            key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \
                expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len)
            if attn_mask is None:
                attn_mask = key_padding_mask
            else:
                attn_mask = attn_mask + key_padding_mask

        dropout_p = self.dropout if self.training else 0.0

        # attn_mask can be either (L,S) or (N*num_heads, L, S)
        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
        # in order to match the input for SDPA of (N, num_heads, L, S)
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len)

        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)

        if past is not None:
            # Only the newly projected k / v are returned; they are stacked
            # before the cache is prepended and before rotary embedding.
            present = torch.stack([k, v], dim=0)  # (2, bsz, num_heads, src_len, head_dim)
            if past.ndim > 2:
                # A real kv-cache; otherwise `past` is just a placeholder and
                # no cached keys / values are used.
                pk, pv = past
                k = torch.cat([pk, k], dim=-2)
                v = torch.cat([pv, v], dim=-2)
        else:
            present = None

        # When using the kv-cache, the position of q must be offset when
        # applying rotary embeddings. This assumes the cache is only used in
        # self-attention, so k and q cover the same positions.
        if sinu is not None:
            # direct rotary
            q, k = apply_rotary_pos_emb(q, k, sinu=sinu, args=args, q_offset=q_offset)
        if q_sinu is not None:
            assert sinu is None, "sinu and q_sinu cannot be used together"
            assert k_sinu is not None, "k_sinu must be provided"
            q, k = apply_rotary_pos_emb(q, k, q_sinu=q_sinu, k_sinu=k_sinu, args=args, q_offset=q_offset)

        # In training, for cross attention, also expose the raw scores when
        # the alignment loss is enabled.
        if (
            args is not None
            and self.training
            and getattr(args, "attention_alignment_loss", 0)
            and query is not key
        ):
            attention_weights = q @ k.transpose(-1, -2)
        else:
            attention_weights = None

        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

        attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)

        # Returning softmax attention weights is not supported on this path;
        # the second element is the kv-cache entry instead.
        if self.batch_first and is_batched:
            attn_output = attn_output.transpose(1, 0)
        if attention_weights is not None:
            return {"attn_output": attn_output, "attention_weights": attention_weights}, present
        return attn_output, present