# Valley-Eagle-7B / modeling_projector.py
import math
import torch
import torch.nn as nn


def build_vision_projector(config, delay_load=False, **kwargs):
projector_type = getattr(config, 'mm_projector_type', 'linear')
if projector_type == 'conv_adapter':
return ConvAdapter(config.mm_hidden_size, config.hidden_size, getattr(config, "mlp_hidden_dim", None))
elif projector_type == 'mlp_pixel_shuffle':
return MlpPixelShuffle(config.mm_hidden_size, config.hidden_size,
config.pixelshuffle_downsample_ratio, getattr(config, "mlp_hidden_dim", None))
elif projector_type == 'ovis_conv_adapter':
return OvisConvAdapter(config.mm_hidden_size, config.hidden_size, getattr(config, "mlp_hidden_dim", 32000),
getattr(config, "tokenize_function", "softmax"))
raise ValueError(f'Unknown projector type: {projector_type}')


class ConvAdapter(nn.Module):
def __init__(self, dim_in, dim_out, mlp_hidden_dim=None):
super().__init__()
self.mm_projector_type = 'conv_adapter'
if mlp_hidden_dim is None:
self.mlp = nn.Sequential(
nn.Linear(dim_in, dim_out),
nn.GELU(),
nn.Linear(dim_out, dim_out)
)
else:
self.mlp = nn.Sequential(
nn.Linear(dim_in, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim_out)
)
self.conv = nn.Conv2d(dim_out, dim_out, kernel_size=(3, 3), stride=(2, 2), padding=1)

    def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
            shape (F, n, D), where n is the reduced number of visual tokens
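            For example, with a 24x24 patch grid plus a cls token (v = 577),
            the stride-2 3x3 convolution reduces the grid to 12x12, so n = 144.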
"""
x = self.mlp(x)
f, v, d = x.shape
s = int(math.sqrt(v - 1))
x = x[:, 1:, :] # remove cls_token
x = x.reshape(f, s, s, d).permute([0, 3, 1, 2])
x = self.conv(x)
x = x.permute([0, 2, 3, 1]).reshape(f, -1, d)
return x


class MlpPixelShuffle(nn.Module):
def __init__(self, dim_in, dim_out, pixelshuffle_downsample_ratio, mlp_hidden_dim=None):
super().__init__()
self.mm_projector_type = 'mlp_pixel_shuffle'
if mlp_hidden_dim is None:
self.mlp = nn.Sequential(
nn.Linear(int(dim_in * (pixelshuffle_downsample_ratio ** 2)), dim_out),
nn.GELU(),
nn.Linear(dim_out, dim_out)
)
else:
self.mlp = nn.Sequential(
nn.Linear(int(dim_in * (pixelshuffle_downsample_ratio ** 2)), mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim_out)
)
self.scale_factor = pixelshuffle_downsample_ratio

    def pixel_shuffle(self, x, scale_factor=2):
        # scale_factor is an integer downsample ratio (e.g. 2), not a fractional scale
n, w, h, c = x.size()
# N, W, H, C --> N, W, H / scale, C * scale
x = x.view(n, w, int(h / scale_factor), int(c * scale_factor))
# N, W, H / scale, C * scale --> N, H / scale, W, C * scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H / scale, W, C * scale --> N, H / scale, W / scale, C * (scale ** 2)
x = x.view(n, int(h / scale_factor), int(w / scale_factor),
int(c * (scale_factor * scale_factor)))
        # N, H / scale, W / scale, C * (scale ** 2) --> N, W / scale, H / scale, C * (scale ** 2)
        x = x.permute(0, 2, 1, 3).contiguous()
return x

    def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
            shape (F, n, D), where n is the reduced number of visual tokens
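            For example, with a 24x24 patch grid plus a cls token (v = 577) and
            a downsample ratio of 2, pixel shuffle yields a 12x12 grid, so n = 144.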
"""
x = x[:, 1:, :] # remove cls_token
h = w = int(x.shape[1] ** 0.5)
x = x.view(x.shape[0], h, w, -1)
x = self.pixel_shuffle(x, self.scale_factor)
x = self.mlp(x)
        x = x.view(x.shape[0], -1, x.shape[-1])
return x


class OvisConvAdapter(nn.Module):
    def __init__(self, dim_in, dim_out, vocab_size, tokenize_function="softmax", tau=1.0):
super().__init__()
self.mm_projector_type = 'ovis_conv_adapter'
self.conv = nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), padding=1)
self.mlp = torch.nn.Sequential(
torch.nn.Linear(dim_in, vocab_size, bias=False),
torch.nn.LayerNorm(vocab_size)
)
self.embedding = torch.nn.Embedding(vocab_size, dim_out)
        self.tokenize_function = tokenize_function
        self.tau = tau  # gumbel_softmax temperature (assumed default; only used when tokenize_function == 'gumbel_argmax')

    def tokenize(self, logits):
        def st_argmax(y_soft, dim):  # straight-through argmax: hard one-hot forward pass, soft gradient
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
return ret
if self.tokenize_function == 'softmax':
tokens = torch.nn.functional.softmax(logits, dim=-1)
elif self.tokenize_function == 'gumbel_argmax':
            tokens = torch.nn.functional.gumbel_softmax(logits, tau=self.tau, hard=True)
elif self.tokenize_function == 'st_argmax':
tokens = st_argmax(logits, dim=-1)
else:
raise ValueError(
                'Invalid `tokenize_function`, expected softmax, gumbel_argmax or st_argmax,'
                f' but got {self.tokenize_function}'
)
return tokens

    def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (F, v, D)
Returns:
            shape (F, n, D), where n is the reduced number of visual tokens
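            For example, with a 24x24 patch grid plus a cls token (v = 577),
            the stride-2 convolution yields n = 144 tokens, each mapped into
            the dim_out embedding space through the visual vocabulary.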
"""
# conv
f, v, d = x.shape
s = int(math.sqrt(v - 1))
x = x[:, 1:, :] # remove cls_token
x = x.reshape(f, s, s, d).permute([0, 3, 1, 2])
x = self.conv(x)
x = x.permute([0, 2, 3, 1]).reshape(f, -1, d)
# tokenize
logits = self.mlp(x)
visual_tokens = self.tokenize(logits)
# get embeddings
out = torch.matmul(visual_tokens, self.embedding.weight)
return out
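

if __name__ == "__main__":
    # Minimal smoke test, not part of the original model code: the sizes below
    # (1024-dim ViT features, a 24x24 patch grid plus a cls token, a 2048-dim
    # language model hidden size) are illustrative assumptions, used only to
    # show how each projector reduces the number of visual tokens.
    from types import SimpleNamespace

    feats = torch.randn(2, 577, 1024)  # (F, v, D): 2 frames, 576 patches + cls token

    conv_adapter = ConvAdapter(dim_in=1024, dim_out=2048)
    print(conv_adapter(feats).shape)  # torch.Size([2, 144, 2048])

    pixel_shuffle = MlpPixelShuffle(dim_in=1024, dim_out=2048, pixelshuffle_downsample_ratio=2)
    print(pixel_shuffle(feats).shape)  # torch.Size([2, 144, 2048])

    ovis_adapter = OvisConvAdapter(dim_in=1024, dim_out=2048, vocab_size=32000)
    print(ovis_adapter(feats).shape)  # torch.Size([2, 144, 2048])

    # The same projectors can be built through the factory, given a config
    # object exposing the attributes read by build_vision_projector.
    cfg = SimpleNamespace(mm_projector_type="mlp_pixel_shuffle", mm_hidden_size=1024,
                          hidden_size=2048, pixelshuffle_downsample_ratio=2)
    projector = build_vision_projector(cfg)
    print(projector(feats).shape)  # torch.Size([2, 144, 2048])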