|
import math |
|
import torch |
|
import torch.nn as nn |
|
|
|
def build_vision_projector(config, delay_load=False, **kwargs):
    """Instantiate the vision projector selected by ``config.mm_projector_type``.

    Args:
        config: Object exposing ``mm_hidden_size`` and ``hidden_size`` plus the
            optional attributes ``mm_projector_type``, ``mlp_hidden_dim``,
            ``pixelshuffle_downsample_ratio`` and ``tokenize_function``.
        delay_load: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A ``ConvAdapter``, ``MlpPixelShuffle`` or ``OvisConvAdapter`` module.

    Raises:
        ValueError: If the projector type is not one of the three supported
            names. This includes the fallback value ``'linear'`` that is used
            when ``mm_projector_type`` is absent from ``config``.
    """
    kind = getattr(config, 'mm_projector_type', 'linear')

    if kind == 'conv_adapter':
        return ConvAdapter(
            config.mm_hidden_size,
            config.hidden_size,
            getattr(config, "mlp_hidden_dim", None),
        )

    if kind == 'mlp_pixel_shuffle':
        return MlpPixelShuffle(
            config.mm_hidden_size,
            config.hidden_size,
            config.pixelshuffle_downsample_ratio,
            getattr(config, "mlp_hidden_dim", None),
        )

    if kind == 'ovis_conv_adapter':
        # NOTE(review): ``mlp_hidden_dim`` is forwarded as the adapter's
        # ``vocab_size`` argument (default 32000) -- confirm this is intended.
        return OvisConvAdapter(
            config.mm_hidden_size,
            config.hidden_size,
            getattr(config, "mlp_hidden_dim", 32000),
            getattr(config, "tokenize_function", "softmax"),
        )

    raise ValueError(f'Unknown projector type: {kind}')
|
|
|
|
|
class ConvAdapter(nn.Module):
    """Two-layer MLP projection followed by a stride-2 3x3 convolution.

    The MLP maps features from ``dim_in`` to ``dim_out``; its hidden width is
    ``mlp_hidden_dim`` when given and ``dim_out`` otherwise. The convolution
    then reduces the number of tokens by downsampling the square token grid.
    """

    def __init__(self, dim_in, dim_out, mlp_hidden_dim=None):
        super().__init__()
        self.mm_projector_type = 'conv_adapter'

        # Without an explicit hidden width the MLP is dim_in -> dim_out -> dim_out.
        width = dim_out if mlp_hidden_dim is None else mlp_hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(dim_in, width),
            nn.GELU(),
            nn.Linear(width, dim_out),
        )
        self.conv = nn.Conv2d(dim_out, dim_out, kernel_size=(3, 3), stride=(2, 2), padding=1)

    def forward(self, x):
        """Project the features and downsample the token grid.

        Args:
            x (torch.Tensor): image features
                shape (F, v, D)
        Returns:
            shape (F, n, D) where n is token_num that has been reduced
        """
        projected = self.mlp(x)
        num_frames, num_tokens, channels = projected.shape

        # The first token is dropped (presumably a class token -- TODO confirm);
        # the remaining tokens are laid out as a side x side grid.
        side = int(math.sqrt(num_tokens - 1))
        grid = projected[:, 1:, :].reshape(num_frames, side, side, channels)

        # Conv2d expects channels-first; restore channels-last afterwards.
        pooled = self.conv(grid.permute([0, 3, 1, 2]))
        return pooled.permute([0, 2, 3, 1]).reshape(num_frames, -1, channels)
|
|
|
|
|
class MlpPixelShuffle(nn.Module):
    """Pixel-shuffle token reduction followed by a two-layer MLP projection.

    The pixel shuffle folds ``ratio x ratio`` neighbourhoods of the token grid
    into the channel dimension, so the MLP input width is
    ``dim_in * ratio ** 2``. The MLP hidden width is ``mlp_hidden_dim`` when
    given and ``dim_out`` otherwise.
    """

    def __init__(self, dim_in, dim_out, pixelshuffle_downsample_ratio, mlp_hidden_dim=None):
        super().__init__()
        self.mm_projector_type = 'mlp_pixel_shuffle'

        shuffled_dim = int(dim_in * (pixelshuffle_downsample_ratio ** 2))
        hidden_dim = dim_out if mlp_hidden_dim is None else mlp_hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(shuffled_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim_out)
        )
        self.scale_factor = pixelshuffle_downsample_ratio

    def pixel_shuffle(self, x, scale_factor=2):
        """Trade spatial resolution for channels.

        Args:
            x (torch.Tensor): grid of features, shape (N, W, H, C).
            scale_factor: downsample ratio applied to both spatial axes.

        Returns:
            Contiguous tensor of shape
            (N, W / scale_factor, H / scale_factor, C * scale_factor ** 2).
        """
        n, w, h, c = x.size()

        # reshape (not view): x may be a non-contiguous slice or permutation,
        # in which case view raises instead of copying.
        # N, W, H, C --> N, W, H / scale, C * scale
        x = x.reshape(n, w, int(h / scale_factor), int(c * scale_factor))

        # N, W, H / scale, C * scale --> N, H / scale, W, C * scale
        x = x.permute(0, 2, 1, 3)

        # N, H / scale, W, C * scale --> N, H / scale, W / scale, C * scale^2
        x = x.reshape(n, int(h / scale_factor), int(w / scale_factor),
                      int(c * (scale_factor * scale_factor)))

        # Swap the two spatial axes back and return a contiguous tensor.
        x = x.permute(0, 2, 1, 3).contiguous()

        return x

    def forward(self, x):
        """Reduce the token count by pixel shuffle, then project with the MLP.

        Args:
            x (torch.Tensor): image features
                shape (F, v, D)
        Returns:
            shape (F, n, D) where n is token_num that has been reduced
        """
        # Drop the first token (presumably a class token -- TODO confirm) and
        # arrange the remaining tokens as a square grid.
        x = x[:, 1:, :]
        h = w = int(x.shape[1] ** 0.5)

        # reshape keeps working when the caller passes a non-contiguous tensor.
        x = x.reshape(x.shape[0], h, w, -1)
        x = self.pixel_shuffle(x, self.scale_factor)
        x = self.mlp(x)

        # Flatten the reduced grid back into a token sequence.
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        return x
|
|
|
|
|
class OvisConvAdapter(nn.Module):
    """Stride-2 convolution followed by a probabilistic visual-token embedding.

    The token grid is downsampled by a 3x3 stride-2 convolution. Each
    remaining token is mapped to ``vocab_size`` logits (Linear + LayerNorm),
    converted to a weighting over the vocabulary by ``tokenize``, and used to
    mix the rows of an embedding table of width ``dim_out``.

    Args:
        dim_in: channel width of the incoming image features.
        dim_out: width of the produced embeddings.
        vocab_size: number of rows in the visual embedding table.
        tokenize_function: one of ``'softmax'``, ``'gumbel_argmax'`` or
            ``'st_argmax'``; validated when ``tokenize`` is called.
        tau: temperature passed to ``gumbel_softmax`` when
            ``tokenize_function`` is ``'gumbel_argmax'``.
    """

    def __init__(self, dim_in, dim_out, vocab_size, tokenize_function="softmax", tau=1.0):
        super().__init__()
        self.mm_projector_type = 'ovis_conv_adapter'
        self.conv = nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), padding=1)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim_in, vocab_size, bias=False),
            torch.nn.LayerNorm(vocab_size)
        )
        self.embedding = torch.nn.Embedding(vocab_size, dim_out)
        self.tokenize_function = tokenize_function
        # Stored on the module itself: this class has no ``config`` attribute.
        self.tau = tau

    def tokenize(self, logits):
        """Turn vocabulary logits into per-token weights over the vocabulary.

        Args:
            logits (torch.Tensor): scores whose last dimension is the vocabulary.

        Returns:
            Tensor with the same shape as ``logits``.

        Raises:
            ValueError: If ``self.tokenize_function`` is not a supported name.
        """
        def st_argmax(y_soft, dim):
            # Straight-through estimator: the forward value is one-hot at the
            # arg-max, while the gradient flows through ``y_soft``.
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
            return ret

        if self.tokenize_function == 'softmax':
            tokens = torch.nn.functional.softmax(logits, dim=-1)
        elif self.tokenize_function == 'gumbel_argmax':
            # getattr keeps instances created without ``tau`` working.
            tokens = torch.nn.functional.gumbel_softmax(
                logits, tau=getattr(self, 'tau', 1.0), hard=True)
        elif self.tokenize_function == 'st_argmax':
            tokens = st_argmax(logits, dim=-1)
        else:
            raise ValueError(
                'Invalid `tokenize_function`, expected softmax or gumbel_argmax or st_argmax,'
                f' but got {self.tokenize_function}'
            )
        return tokens

    def forward(self, x):
        """Downsample the token grid and embed the resulting visual tokens.

        Args:
            x (torch.Tensor): image features
                shape (F, v, D)
        Returns:
            shape (F, n, dim_out) where n is token_num that has been reduced
        """
        f, v, d = x.shape

        # Drop the first token (presumably a class token -- TODO confirm) and
        # arrange the remaining tokens as an s x s grid.
        s = int(math.sqrt(v - 1))
        x = x[:, 1:, :]

        # Conv2d expects channels-first; restore channels-last afterwards.
        x = x.reshape(f, s, s, d).permute([0, 3, 1, 2])
        x = self.conv(x)
        x = x.permute([0, 2, 3, 1]).reshape(f, -1, d)

        logits = self.mlp(x)
        visual_tokens = self.tokenize(logits)

        # Weighted mix of embedding rows: (F, n, vocab) @ (vocab, dim_out).
        out = torch.matmul(visual_tokens, self.embedding.weight)

        return out
|
|