import torch
import torch.nn as nn

from transformers import PretrainedConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

# Static config for the SigLIP encoder (siglip-so400m-patch14-384):
# 27 layers, 1152-dim hidden states, 14px patches on 384px inputs.
siglip_config = PretrainedConfig.from_dict(
    {
        "attention_dropout": 0.0,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "image_size": 384,
        "intermediate_size": 4304,
        "layer_norm_eps": 1e-06,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_channels": 3,
        "num_hidden_layers": 27,
        "patch_size": 14,
    }
)

# Static config for the Qwen2-VL vision transformer (the "navit" tower referred
# to by the only_navit / navit_merger_* flags below): 32 blocks, 1280-dim patch
# embeddings, 2x2 spatial merge, projected to a 3584-dim output.
qwen2vl_vit_config = PretrainedConfig.from_dict(
    {
        "depth": 32,
        "embed_dim": 1280,
        "hidden_act": "quick_gelu",
        "hidden_size": 3584,
        "in_channels": 3,
        "in_chans": 3,
        "mlp_ratio": 4,
        "model_type": "qwen2_vl",
        "num_heads": 16,
        "patch_size": 14,
        "spatial_merge_size": 2,
        "spatial_patch_size": 14,
        "temporal_patch_size": 2,
        "_attn_implementation": "flash_attention_2",
        "_attn_implementation_internal": "flash_attention_2",
    }
)


def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    # Guard against a missing tower name before the substring check below; a
    # None value falls through to the ValueError at the bottom.
    if vision_tower is not None and "siglip-so400m-patch14-384" in vision_tower:
if getattr(vision_tower_cfg, "eagle_vision_tower", None) is not None: |
|
if getattr(vision_tower_cfg, "_vit_attn_implementation", None) is not None: |
|
qwen2vl_vit_config._attn_implementation = vision_tower_cfg._vit_attn_implementation |
|
qwen2vl_vit_config._attn_implementation_internal = vision_tower_cfg._vit_attn_implementation |
|
|
|
qwen2vl_vision_tower = Qwen2VisionTransformerPretrainedModel._from_config(qwen2vl_vit_config) |
|
|
|
if getattr(vision_tower_cfg, "navit_merger_hidden_dim", None) is not None: |
|
del qwen2vl_vision_tower.merger |
|
qwen2vl_vision_tower.merger = CustomPatchMerger( |
|
vision_tower_cfg.hidden_size, |
|
context_dim=1280, |
|
hidden_dim=getattr(vision_tower_cfg, "navit_merger_hidden_dim", None) |
|
) |
|
qwen2vl_vision_tower.requires_grad_(False) |
|
|
|
|
|
if getattr(vision_tower_cfg, "only_navit", False): |
|
siglip_vision_tower = None |
|
else: |
|
siglip_vision_tower = SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) |
|
|
|
return siglip_vision_tower, qwen2vl_vision_tower |
|
|
|

        else:
            siglip_vision_tower = SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
            return siglip_vision_tower
    else:
        raise ValueError(f"Unknown vision tower: {vision_tower}")
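
# Attributes build_vision_tower reads off vision_tower_cfg, gathered here for
# reference: mm_vision_tower / vision_tower (checkpoint name), eagle_vision_tower
# (any non-None value enables the dual-tower path), _vit_attn_implementation
# (optional attention-backend override), navit_merger_hidden_dim (optional
# custom-merger width), only_navit (drop the SigLIP branch), and hidden_size
# (LLM width fed to the custom merger).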


class SigLipVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False, cache_dir="./cache_dir"):
        super().__init__()
        self.is_loaded = False
        self.image_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer  # kept for config parity; forward() uses last_hidden_state
        # feature_select() only accepts "cls_patch", so default to it rather than
        # the unsupported "patch", which would always trip its assert.
        self.select_feature = getattr(args, "mm_vision_select_feature", "cls_patch")
        self.cache_dir = cache_dir

        if not delay_load:
            self.load_model()
        else:
            from transformers import SiglipVisionModel

            # Even with delay_load, the tower is still materialized from the
            # static config (randomly initialized, is_loaded stays False);
            # pretrained weights are expected to come from the enclosing
            # checkpoint later.
            self.cfg_only = siglip_config
            self.vision_tower = SiglipVisionModel._from_config(siglip_config)

    def load_model(self):
        from transformers import SiglipVisionModel

        self.vision_tower = SiglipVisionModel._from_config(siglip_config)
        self.vision_tower.requires_grad_(False)  # the SigLIP tower stays frozen
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        # SigLIP has no [CLS] token, so "cls_patch" duplicates the first patch
        # token in front of the sequence as a stand-in class embedding.
        assert self.select_feature == "cls_patch"
        image_features = torch.cat([image_forward_outs[:, :1, :], image_forward_outs], dim=1)
        return image_features

    def forward(self, images):
        if isinstance(images, list):
            # Variable-sized inputs: encode one image at a time.
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                    return_dict=True,
                )
                image_feature = self.feature_select(image_forward_out.last_hidden_state).to(image.dtype)
                image_features.append(image_feature)
        else:
            # Batched input: a single (B, C, H, W) tensor.
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
                return_dict=True,
            )
            image_features = self.feature_select(image_forward_outs.last_hidden_state).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        # (384 // 14) ** 2 = 27 ** 2 = 729 patches per image.
        return (self.config.image_size // self.config.patch_size) ** 2


class CustomPatchMerger(nn.Module):
    """Replacement for the Qwen2-VL patch merger: fuses spatial_merge_size**2
    neighboring patch embeddings (context_dim each) into one dim-sized output
    token through a configurable-width MLP."""

    def __init__(self, dim: int, context_dim: int, hidden_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        self.input_dim = context_dim * (spatial_merge_size**2)
        self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # LayerNorm each patch embedding, flatten each merge window of
        # spatial_merge_size**2 patches into one vector, then project to dim.
        x = self.mlp(self.ln_q(x).view(-1, self.input_dim))
        return x
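

# A minimal smoke-test sketch (not part of the original module): it assumes a
# config object exposing the attributes listed above, with SimpleNamespace as a
# hypothetical stand-in for the real model config, and overrides the ViT
# attention backend to "sdpa" so the test does not require flash-attn. Both
# towers are built with random weights here, so output values are meaningless;
# only shapes are exercised.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        mm_vision_tower="siglip-so400m-patch14-384",  # must contain the checkpoint substring
        eagle_vision_tower="qwen2vl",                 # any non-None value enables the dual-tower path
        _vit_attn_implementation="sdpa",              # avoid the flash_attention_2 default
        only_navit=False,
        hidden_size=3584,
        mm_vision_select_layer=-2,
        mm_vision_select_feature="cls_patch",
    )
    siglip_tower, qwen2vl_tower = build_vision_tower(cfg)
    feats = siglip_tower(torch.randn(2, 3, 384, 384))
    print(feats.shape)  # torch.Size([2, 730, 1152]): 729 patches + duplicated lead token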