File size: 10,110 Bytes
d643072 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 |
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
import torch
import torch.nn as nn
from timm.models.vision_transformer import Mlp
from diffusion.model.act import build_act, get_act_name
from diffusion.model.norms import build_norm, get_norm_name
from diffusion.model.utils import get_same_padding, val2tuple
class ConvLayer(nn.Module):
    """2-D convolution followed by optional normalization and activation.

    An ``nn.Dropout2d`` is applied to the *input* (before the convolution)
    when ``dropout > 0``.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: square kernel size.
        stride: square stride.
        dilation: square dilation.
        groups: convolution groups (``groups == in_dim`` gives a depthwise conv).
        padding: explicit padding; when ``None`` it is derived from
            ``get_same_padding(kernel_size)`` multiplied by ``dilation``.
        use_bias: whether the convolution carries a bias term.
        dropout: ``Dropout2d`` probability for the input; disabled unless ``> 0``.
        norm: spec forwarded to ``build_norm`` (with ``num_features=out_dim``).
        act: spec forwarded to ``build_act``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        padding: int or None = None,
        use_bias=False,
        dropout=0.0,
        norm="bn2d",
        act="relu",
    ):
        super().__init__()
        if padding is None:
            # Padding derived from the kernel size, widened for dilated kernels.
            padding = get_same_padding(kernel_size) * dilation

        # Construction hyper-parameters, kept as public attributes for introspection.
        self.in_dim, self.out_dim = in_dim, out_dim
        self.kernel_size, self.stride, self.dilation = kernel_size, stride, dilation
        self.groups, self.padding, self.use_bias = groups, padding, use_bias

        # Submodules are created in a fixed order: dropout, conv, norm, act.
        if dropout > 0:
            self.dropout = nn.Dropout2d(dropout, inplace=False)
        else:
            self.dropout = None
        self.conv = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=(kernel_size,) * 2,
            stride=(stride,) * 2,
            padding=padding,
            dilation=(dilation,) * 2,
            groups=groups,
            bias=use_bias,
        )
        self.norm = build_norm(norm, num_features=out_dim)
        self.act = build_act(act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply (dropout) -> conv -> (norm) -> (act) to ``x``."""
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.conv(x)
        # norm/act may be None (or otherwise falsy) when disabled; skip them then.
        for layer in (self.norm, self.act):
            if layer:
                x = layer(x)
        return x
class GLUMBConv(nn.Module):
    """Gated inverted-bottleneck conv block that operates on token sequences.

    Pipeline (on the tokens reshaped to an ``(H, W)`` grid):

    1. 1x1 ``inverted_conv``: ``in_features -> 2 * hidden_features`` (act ``act[0]``).
    2. depthwise ``kernel_size`` ``depth_conv`` on ``2 * hidden_features`` channels
       (no activation).
    3. split channels into value/gate halves; ``value * glu_act(gate)`` where
       ``glu_act`` is built from ``act[1]``.
    4. 1x1 ``point_conv``: ``hidden_features -> out_feature`` (act ``act[2]``).

    Args:
        in_features: channel width of the input tokens.
        hidden_features: width of each GLU half.
        out_feature: output channel width; defaults to ``in_features``.
        kernel_size, stride, padding, dilation: settings of the depthwise conv.
        use_bias, norm, act: per-stage settings, expanded to 3-tuples via ``val2tuple``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_feature=None,
        kernel_size=3,
        stride=1,
        padding: int or None = None,
        use_bias=False,
        norm=(None, None, None),
        act=("silu", "silu", None),
        dilation=1,
    ):
        out_feature = out_feature or in_features
        super().__init__()
        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act = val2tuple(act, 3)
        self.glu_act = build_act(act[1], inplace=False)
        self.inverted_conv = ConvLayer(
            in_features,
            hidden_features * 2,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act=act[0],
        )
        self.depth_conv = ConvLayer(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size,
            stride=stride,
            groups=hidden_features * 2,
            padding=padding,
            use_bias=use_bias[1],
            norm=norm[1],
            act=None,
            dilation=dilation,
        )
        self.point_conv = ConvLayer(
            hidden_features,
            out_feature,
            1,
            use_bias=use_bias[2],
            norm=norm[2],
            act=act[2],
        )

    def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor:
        """Run the block on a token sequence.

        Args:
            x: tokens of shape ``(B, N, C)`` laid out row-major over an
                ``(H, W)`` grid, so ``N`` must equal ``H * W``.
            HW: optional ``(H, W)``; when ``None`` a square grid
                ``H = W = int(sqrt(N))`` is assumed.

        Returns:
            Tokens of shape ``(B, N', out_feature)``, where ``N'`` is the number of
            positions in the output grid (``N' == N`` when ``stride == 1``).
        """
        B, N, C = x.shape
        if HW is None:
            H = W = int(N**0.5)
        else:
            H, W = HW
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)

        x = self.inverted_conv(x)
        x = self.depth_conv(x)

        x, gate = torch.chunk(x, 2, dim=1)
        gate = self.glu_act(gate)
        x = x * gate

        x = self.point_conv(x)
        # Flatten based on the *output* tensor: point_conv emits out_feature
        # channels (not necessarily C), and a stride > 1 shrinks the grid.
        return x.flatten(2).transpose(1, 2)
class SlimGLUMBConv(GLUMBConv):
    """``GLUMBConv`` without the leading 1x1 expansion conv.

    Input tokens go straight into the depthwise conv, so their channel count
    must already equal ``2 * hidden_features`` (the depthwise conv's input width).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The parent builds an expansion layer that this variant never uses; drop it.
        del self.inverted_conv
        self.out_dim = self.point_conv.out_dim

    def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor:
        """Map ``(B, N, C)`` tokens to ``(B, N, out_dim)`` tokens.

        ``HW`` gives the ``(H, W)`` token grid; a square grid of side
        ``int(sqrt(N))`` is assumed when it is ``None``.
        """
        batch, num_tokens, channels = x.shape
        height, width = (int(num_tokens**0.5),) * 2 if HW is None else HW

        # Tokens -> NCHW feature map, fed directly to the depthwise conv.
        feat = x.reshape(batch, height, width, channels).permute(0, 3, 1, 2)
        feat = self.depth_conv(feat)

        value, gate = torch.chunk(feat, 2, dim=1)
        feat = self.point_conv(value * self.glu_act(gate))

        return feat.reshape(batch, self.out_dim, num_tokens).permute(0, 2, 1)
class MBConvPreGLU(nn.Module):
    """Inverted-bottleneck conv block with the GLU gate applied *before* the depthwise conv.

    Pipeline (on the tokens reshaped to an ``(H, W)`` grid):

    1. 1x1 ``inverted_conv``: ``in_dim -> 2 * mid_dim`` (no activation).
    2. split into value/gate halves; ``value * glu_act(gate)`` with ``glu_act``
       built from ``act[0]``.
    3. depthwise ``kernel_size`` ``depth_conv`` on ``mid_dim`` channels (act ``act[1]``).
    4. 1x1 ``point_conv``: ``mid_dim -> out_dim`` (act ``act[2]``).

    Args:
        in_dim: channel width of the input tokens.
        out_dim: channel width of the output tokens.
        kernel_size, stride, padding: settings of the depthwise conv.
        mid_dim: bottleneck width; defaults to ``round(in_dim * expand)``.
        expand: expansion ratio used when ``mid_dim`` is not given.
        use_bias, norm, act: per-stage settings, expanded to 3-tuples via ``val2tuple``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        kernel_size=3,
        stride=1,
        mid_dim=None,
        expand=6,
        padding: int or None = None,
        use_bias=False,
        norm=(None, None, "ln2d"),
        act=("silu", "silu", None),
    ):
        super().__init__()
        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act = val2tuple(act, 3)

        mid_dim = mid_dim or round(in_dim * expand)

        self.inverted_conv = ConvLayer(
            in_dim,
            mid_dim * 2,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act=None,
        )
        self.glu_act = build_act(act[0], inplace=False)
        self.depth_conv = ConvLayer(
            mid_dim,
            mid_dim,
            kernel_size,
            stride=stride,
            groups=mid_dim,
            padding=padding,
            use_bias=use_bias[1],
            norm=norm[1],
            act=act[1],
        )
        self.point_conv = ConvLayer(
            mid_dim,
            out_dim,
            1,
            use_bias=use_bias[2],
            norm=norm[2],
            act=act[2],
        )

    def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor:
        """Run the block on a token sequence.

        Args:
            x: tokens of shape ``(B, N, C)`` laid out row-major over an
                ``(H, W)`` grid, so ``N`` must equal ``H * W``.
            HW: optional ``(H, W)``; when ``None`` a square grid
                ``H = W = int(sqrt(N))`` is assumed.

        Returns:
            Tokens of shape ``(B, N', out_dim)``, where ``N'`` is the number of
            positions in the output grid (``N' == N`` when ``stride == 1``).
        """
        B, N, C = x.shape
        if HW is None:
            H = W = int(N**0.5)
        else:
            H, W = HW
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)

        x = self.inverted_conv(x)
        x, gate = torch.chunk(x, 2, dim=1)
        gate = self.glu_act(gate)
        x = x * gate

        x = self.depth_conv(x)
        x = self.point_conv(x)

        # Flatten based on the *output* tensor: point_conv emits out_dim channels
        # (not necessarily C), and a stride > 1 shrinks the grid.
        return x.flatten(2).transpose(1, 2)

    @property
    def module_str(self) -> str:
        """Compact one-line description of the block's configuration."""
        _str = f"{self.depth_conv.kernel_size}{type(self).__name__}("
        _str += (
            f"in={self.inverted_conv.in_dim},mid={self.depth_conv.in_dim},"
            f"out={self.point_conv.out_dim},s={self.depth_conv.stride}"
        )
        _str += (
            f",norm={get_norm_name(self.inverted_conv.norm)}"
            f"+{get_norm_name(self.depth_conv.norm)}"
            f"+{get_norm_name(self.point_conv.norm)}"
        )
        _str += (
            f",act={get_act_name(self.inverted_conv.act)}"
            f"+{get_act_name(self.depth_conv.act)}"
            f"+{get_act_name(self.point_conv.act)}"
        )
        _str += f",glu_act={get_act_name(self.glu_act)})"
        return _str
class DWMlp(Mlp):
    """Token MLP with a depthwise convolution between the two linear layers.

    ``fc1 -> act -> drop1`` runs on the token sequence. The hidden features are
    then reshaped to an ``(H, W)`` grid, passed through a depthwise
    ``kernel_size`` convolution (no activation after it), flattened back to
    tokens, and finished with ``fc2 -> drop2``.

    Args:
        in_features, hidden_features, out_features, act_layer, bias, drop:
            forwarded to the base ``Mlp``; ``hidden_features`` defaults to
            ``in_features``.
        kernel_size, stride, dilation: settings of the depthwise conv.
        padding: explicit conv padding; when ``None`` it is derived from
            ``get_same_padding(kernel_size)`` multiplied by ``dilation``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        drop=0.0,
        kernel_size=3,
        stride=1,
        dilation=1,
        padding=None,
    ):
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            act_layer=act_layer,
            bias=bias,
            drop=drop,
        )
        hidden_features = hidden_features or in_features
        self.hidden_features = hidden_features
        if padding is None:
            padding = get_same_padding(kernel_size)
            padding *= dilation

        # groups == channels makes this a depthwise convolution.
        self.conv = nn.Conv2d(
            hidden_features,
            hidden_features,
            kernel_size=(kernel_size, kernel_size),
            stride=(stride, stride),
            padding=padding,
            dilation=(dilation, dilation),
            groups=hidden_features,
            bias=bias,
        )

    def forward(self, x, HW=None):
        """Map ``(B, N, in_features)`` tokens to ``(B, N', out_features)`` tokens.

        ``x`` must be laid out row-major over an ``(H, W)`` grid with
        ``N == H * W``. ``HW`` gives that grid; a square grid of side
        ``int(sqrt(N))`` is assumed when it is ``None``. ``N'`` is the number of
        positions in the conv output grid (``N' == N`` when ``stride == 1``).
        """
        B, N, _ = x.shape
        if HW is None:
            H = W = int(N**0.5)
        else:
            H, W = HW
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = x.reshape(B, H, W, self.hidden_features).permute(0, 3, 1, 2)
        x = self.conv(x)
        # Flatten from the conv output so a stride > 1 (fewer tokens) still works.
        x = x.flatten(2).transpose(1, 2)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
class Mlp(Mlp):
    """Base ``Mlp`` whose ``forward`` also accepts an (ignored) ``HW`` argument.

    The extra argument presumably gives this class the same call signature as
    the conv-based MLP variants in this module, whose forwards take ``HW``.
    Note that this class definition rebinds the module-level name ``Mlp``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.0):
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            act_layer=act_layer,
            bias=bias,
            drop=drop,
        )

    def forward(self, x, HW=None):
        """Apply ``fc1 -> act -> drop1 -> fc2 -> drop2``; ``HW`` is unused."""
        # NOTE(review): unlike recent timm Mlp.forward, no self.norm is applied here;
        # presumably harmless because no norm_layer is passed — confirm for the
        # installed timm version.
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
if __name__ == "__main__":
    # Smoke test: one GLUMBConv forward pass on a 16x16 token grid, with a shape check.
    # Falls back to CPU so the check also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GLUMBConv(
        1152,
        1152 * 4,
        1152,
        use_bias=(True, True, False),
        norm=(None, None, None),
        act=("silu", "silu", None),
    ).to(device)
    tokens = torch.randn(4, 256, 1152, device=device)
    with torch.no_grad():
        output = model(tokens)
    assert output.shape == tokens.shape, f"unexpected output shape {tuple(output.shape)}"
    print(f"GLUMBConv output shape: {tuple(output.shape)}")
|