import torch
from flash_attn import flash_attn_func
from torch import nn


class FlashAttention(nn.Module):
    """Multi-head self-attention backed by the FlashAttention kernel."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Fused projection producing queries, keys, and values in one matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj_out = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        # Project to (B, N, 3, num_heads, head_dim) and split into q, k, v.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        # flash_attn_func expects (B, N, num_heads, head_dim) tensors in
        # fp16/bf16 on a CUDA device and returns the same layout.
        out = flash_attn_func(q, k, v)
        # Merge heads back and apply the output projection.
        out = self.proj_out(out.reshape(B, N, C))
        return out
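

# Minimal usage sketch (illustrative, not part of the original module):
# flash_attn_func requires fp16/bf16 tensors on a CUDA device, so the module
# and its input are cast to half precision here; the sizes are arbitrary.
if __name__ == "__main__":
    attn = FlashAttention(dim=256, num_heads=8).cuda().half()
    x = torch.randn(2, 128, 256, device="cuda", dtype=torch.float16)
    out = attn(x)
    print(out.shape)  # torch.Size([2, 128, 256])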