import math
import torch
import torch.nn as nn

def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'conv_adapter':
        return ConvAdapter(config.mm_hidden_size, config.hidden_size,
                           getattr(config, "mlp_hidden_dim", None))
    elif projector_type == 'mlp_pixel_shuffle':
        return MlpPixelShuffle(config.mm_hidden_size, config.hidden_size,
                               config.pixelshuffle_downsample_ratio,
                               getattr(config, "mlp_hidden_dim", None))
    elif projector_type == 'ovis_conv_adapter':
        # NOTE: for this projector, mlp_hidden_dim is used as the visual vocab size
        return OvisConvAdapter(config.mm_hidden_size, config.hidden_size,
                               getattr(config, "mlp_hidden_dim", 32000),
                               getattr(config, "tokenize_function", "softmax"),
                               getattr(config, "tau", 1.0))
    raise ValueError(f'Unknown projector type: {projector_type}')
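
# Hedged usage sketch (not part of the original module): the factory only reads
# attributes off `config`, so a SimpleNamespace carrying the expected fields is
# enough to exercise it. The field values below are illustrative assumptions.
def _demo_build_vision_projector():
    from types import SimpleNamespace
    cfg = SimpleNamespace(mm_projector_type='conv_adapter',
                          mm_hidden_size=1024, hidden_size=4096)
    proj = build_vision_projector(cfg)
    assert isinstance(proj, ConvAdapter)
    return proj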

class ConvAdapter(nn.Module):
    def __init__(self, dim_in, dim_out, mlp_hidden_dim=None):
        super().__init__()
        self.mm_projector_type = 'conv_adapter'
        if mlp_hidden_dim is None:
            self.mlp = nn.Sequential(
                nn.Linear(dim_in, dim_out),
                nn.GELU(),
                nn.Linear(dim_out, dim_out)
            )
        else:
            self.mlp = nn.Sequential(
                nn.Linear(dim_in, mlp_hidden_dim),
                nn.GELU(),
                nn.Linear(mlp_hidden_dim, dim_out)
            )
        # 3x3 stride-2 conv halves each spatial dim, cutting the token count 4x
        self.conv = nn.Conv2d(dim_out, dim_out, kernel_size=(3, 3), stride=(2, 2), padding=1)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features of shape (F, v, D),
                where v counts a leading CLS token plus an s*s patch grid
        Returns:
            torch.Tensor of shape (F, n, D), where n is the reduced token count
        """
        x = self.mlp(x)
        f, v, d = x.shape
        s = int(math.sqrt(v - 1))  # side length of the square patch grid
        x = x[:, 1:, :]  # remove cls_token
        x = x.reshape(f, s, s, d).permute([0, 3, 1, 2])  # (F, D, s, s)
        x = self.conv(x)
        x = x.permute([0, 2, 3, 1]).reshape(f, -1, d)  # back to (F, n, D)
        return x
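
# Hedged shape-check sketch (not part of the original module): with 1 CLS +
# 24x24 = 577 input tokens, the stride-2 conv reduces 576 patch tokens to
# 12x12 = 144. The widths (32 in, 64 out) are illustrative assumptions.
def _demo_conv_adapter():
    feats = torch.randn(2, 1 + 24 * 24, 32)  # (F, v, D)
    out = ConvAdapter(32, 64)(feats)
    assert out.shape == (2, 144, 64)
    return out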

class MlpPixelShuffle(nn.Module):
    def __init__(self, dim_in, dim_out, pixelshuffle_downsample_ratio, mlp_hidden_dim=None):
        super().__init__()
        self.mm_projector_type = 'mlp_pixel_shuffle'
        # pixel shuffle folds a ratio x ratio neighborhood into the channel dim,
        # so the MLP input width is dim_in * ratio**2
        if mlp_hidden_dim is None:
            self.mlp = nn.Sequential(
                nn.Linear(int(dim_in * (pixelshuffle_downsample_ratio ** 2)), dim_out),
                nn.GELU(),
                nn.Linear(dim_out, dim_out)
            )
        else:
            self.mlp = nn.Sequential(
                nn.Linear(int(dim_in * (pixelshuffle_downsample_ratio ** 2)), mlp_hidden_dim),
                nn.GELU(),
                nn.Linear(mlp_hidden_dim, dim_out)
            )
        self.scale_factor = pixelshuffle_downsample_ratio

    def pixel_shuffle(self, x, scale_factor=2):
        # Space-to-depth downsampling: each scale_factor x scale_factor block of
        # spatial positions is folded into the channel dimension
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H / scale, C * scale
        x = x.view(n, w, int(h / scale_factor), int(c * scale_factor))
        # N, W, H / scale, C * scale --> N, H / scale, W, C * scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H / scale, W, C * scale --> N, H / scale, W / scale, C * (scale ** 2)
        x = x.view(n, int(h / scale_factor), int(w / scale_factor),
                   int(c * (scale_factor * scale_factor)))
        # N, H / scale, W / scale, C * (scale ** 2) --> N, W / scale, H / scale, C * (scale ** 2)
        x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features of shape (F, v, D),
                where v counts a leading CLS token plus a square patch grid
        Returns:
            torch.Tensor of shape (F, n, D), where n is the reduced token count
        """
        x = x[:, 1:, :]  # remove cls_token
        h = w = int(x.shape[1] ** 0.5)
        x = x.view(x.shape[0], h, w, -1)
        x = self.pixel_shuffle(x, self.scale_factor)
        x = self.mlp(x)
        x = x.view(x.shape[0], -1, x.shape[-1])
        return x
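
# Hedged shape-check sketch (not part of the original module): with ratio 2,
# pixel shuffle turns a 24x24 grid into 12x12 tokens of 4x the channel width,
# which the MLP then projects. Widths (32 in, 64 out) are illustrative.
def _demo_mlp_pixel_shuffle():
    feats = torch.randn(2, 1 + 24 * 24, 32)  # (F, v, D)
    out = MlpPixelShuffle(32, 64, pixelshuffle_downsample_ratio=2)(feats)
    assert out.shape == (2, 144, 64)
    return out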

class OvisConvAdapter(nn.Module):
    def __init__(self, dim_in, dim_out, vocab_size, tokenize_function="softmax", tau=1.0):
        super().__init__()
        self.mm_projector_type = 'ovis_conv_adapter'
        # 3x3 stride-2 conv halves each spatial dim, cutting the token count 4x
        self.conv = nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), padding=1)
        # project features to logits over a discrete visual vocabulary
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim_in, vocab_size, bias=False),
            torch.nn.LayerNorm(vocab_size)
        )
        self.embedding = torch.nn.Embedding(vocab_size, dim_out)
        self.tokenize_function = tokenize_function
        # gumbel_argmax temperature; the original read self.config.tau, which is
        # never set on this module, so it is stored explicitly here
        self.tau = tau

    def tokenize(self, logits):
        def st_argmax(y_soft, dim):  # straight-through argmax: hard forward, soft gradient
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
            return ret

        if self.tokenize_function == 'softmax':
            tokens = torch.nn.functional.softmax(logits, dim=-1)
        elif self.tokenize_function == 'gumbel_argmax':
            tokens = torch.nn.functional.gumbel_softmax(logits, tau=self.tau, hard=True)
        elif self.tokenize_function == 'st_argmax':
            tokens = st_argmax(logits, dim=-1)
        else:
            raise ValueError(
                'Invalid `tokenize_function`, expected softmax, gumbel_argmax or st_argmax,'
                f' but got {self.tokenize_function}'
            )
        return tokens

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features of shape (F, v, D),
                where v counts a leading CLS token plus an s*s patch grid
        Returns:
            torch.Tensor of shape (F, n, D), where n is the reduced token count
        """
        # conv: downsample the patch grid
        f, v, d = x.shape
        s = int(math.sqrt(v - 1))
        x = x[:, 1:, :]  # remove cls_token
        x = x.reshape(f, s, s, d).permute([0, 3, 1, 2])  # (F, D, s, s)
        x = self.conv(x)
        x = x.permute([0, 2, 3, 1]).reshape(f, -1, d)  # back to (F, n, D)
        # tokenize: soft assignment over the visual vocabulary
        logits = self.mlp(x)
        visual_tokens = self.tokenize(logits)
        # get embeddings: soft lookup via matmul with the embedding table
        out = torch.matmul(visual_tokens, self.embedding.weight)
        return out
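
# Hedged shape-check sketch (not part of the original module): the conv cuts
# 24x24 patch tokens to 12x12, logits over a 128-entry visual vocabulary are
# softmaxed, and the soft one-hots are matmul'd with the embedding table.
# All sizes here are illustrative assumptions.
def _demo_ovis_conv_adapter():
    feats = torch.randn(2, 1 + 24 * 24, 32)  # (F, v, D)
    out = OvisConvAdapter(32, 64, vocab_size=128)(feats)
    assert out.shape == (2, 144, 64)
    return out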