# Valley-Eagle-7B / modeling_vision_tower.py
import torch
import torch.nn as nn
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
from transformers import PretrainedConfig
siglip_config = PretrainedConfig.from_dict(
{
"attention_dropout": 0.0,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"layer_norm_eps": 1e-06,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 27,
"patch_size": 14,
}
)
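
# The SigLIP config above appears to match the vision encoder of
# google/siglip-so400m-patch14-384 (27 layers, hidden size 1152, patch size 14,
# 384x384 inputs); the tower is built from config here, with weights presumably
# loaded from the full model checkpoint afterwards.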
qwen2vl_vit_config = PretrainedConfig.from_dict(
{
"depth": 32,
"embed_dim": 1280,
"hidden_act": "quick_gelu",
"hidden_size": 3584,
"in_channels": 3,
"in_chans": 3,
"mlp_ratio": 4,
"model_type": "qwen2_vl",
"num_heads": 16,
"patch_size": 14,
"spatial_merge_size": 2,
"spatial_patch_size": 14,
"temporal_patch_size": 2,
"_attn_implementation": "flash_attention_2",
"_attn_implementation_internal": "flash_attention_2"
}
)
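
# The config above mirrors the Qwen2-VL-7B vision transformer (NaViT-style encoder
# with 2x2 spatial patch merging into a 3584-d output). Attention defaults to
# flash_attention_2; build_vision_tower below lets a model config override it via
# `_vit_attn_implementation`.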
def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
if "siglip-so400m-patch14-384" in vision_tower:
        # Eagle: build the Qwen2-VL (NaViT) tower alongside the SigLIP tower
if getattr(vision_tower_cfg, "eagle_vision_tower", None) is not None:
if getattr(vision_tower_cfg, "_vit_attn_implementation", None) is not None:
qwen2vl_vit_config._attn_implementation = vision_tower_cfg._vit_attn_implementation
qwen2vl_vit_config._attn_implementation_internal = vision_tower_cfg._vit_attn_implementation
qwen2vl_vision_tower = Qwen2VisionTransformerPretrainedModel._from_config(qwen2vl_vit_config)
if getattr(vision_tower_cfg, "navit_merger_hidden_dim", None) is not None:
del qwen2vl_vision_tower.merger
qwen2vl_vision_tower.merger = CustomPatchMerger(
vision_tower_cfg.hidden_size,
context_dim=1280,
hidden_dim=getattr(vision_tower_cfg, "navit_merger_hidden_dim", None)
                )  # randomly initialized
qwen2vl_vision_tower.requires_grad_(False)
            # If only the NaViT tower is used, skip building the SigLIP tower
if getattr(vision_tower_cfg, "only_navit", False):
siglip_vision_tower = None
else:
siglip_vision_tower = SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
return siglip_vision_tower, qwen2vl_vision_tower
        # Non-Eagle: SigLIP tower only
else:
siglip_vision_tower = SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
return siglip_vision_tower
else:
raise ValueError(f"Unknown vision tower: {vision_tower}")
class SigLipVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False, cache_dir="./cache_dir"):
super().__init__()
self.is_loaded = False
self.image_tower_name = vision_tower
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
self.cache_dir = cache_dir
if not delay_load:
self.load_model()
else:
from transformers import SiglipVisionModel
self.cfg_only = siglip_config
self.vision_tower = SiglipVisionModel._from_config(siglip_config) # dummy-load
def load_model(self):
from transformers import SiglipVisionModel
self.vision_tower = SiglipVisionModel._from_config(siglip_config)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
assert self.select_feature == "cls_patch"
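        # SigLIP has no dedicated CLS token; duplicate the first patch token at the
        # front of the sequence so the output is [B, N + 1, D].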
image_features = torch.cat([image_forward_outs[:, :1, :], image_forward_outs], dim=1)
return image_features
def forward(self, images):
        if isinstance(images, list):
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
return_dict=True,
)
image_feature = self.feature_select(image_forward_out.last_hidden_state).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype),
output_hidden_states=True,
return_dict=True,
)
image_features = self.feature_select(image_forward_outs.last_hidden_state).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
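
# CustomPatchMerger stands in for the stock Qwen2-VL patch merger when
# `navit_merger_hidden_dim` is set in build_vision_tower: it layer-norms each
# 1280-d patch embedding, groups every spatial_merge_size**2 neighboring patches
# into one vector, and projects it to the target hidden size with a two-layer MLP.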
class CustomPatchMerger(nn.Module):
def __init__(self, dim: int, context_dim: int, hidden_dim: int, spatial_merge_size: int = 2) -> None:
super().__init__()
self.input_dim = context_dim * (spatial_merge_size**2)
self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(self.input_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.mlp(self.ln_q(x).view(-1, self.input_dim))
return x
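
# A minimal shape sketch (not part of the original module): it instantiates
# CustomPatchMerger with the same context_dim used in build_vision_tower and a
# hypothetical hidden_dim / output dim to show how every spatial_merge_size**2 = 4
# neighboring 1280-d patch embeddings are merged into a single token.
if __name__ == "__main__":
    merger = CustomPatchMerger(dim=3584, context_dim=1280, hidden_dim=4096)
    patches = torch.randn(64, 1280)  # 64 patch embeddings from the NaViT tower
    merged = merger(patches)         # every 4 consecutive patches -> one 3584-d token
    print(merged.shape)              # torch.Size([16, 3584])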